use crate::constraints::ConstraintValue;
use crate::extraction::{
extract_all, CompiledExtractionRules, ExtractionError, ExtractionRule, ExtractionSource,
ExtractionTrace, RequestContext,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[cfg(not(target_arch = "wasm32"))]
use std::path::Path;
use std::sync::Arc;
/// Top-level gateway configuration, usually loaded with [`GatewayConfig::from_yaml`] or
/// [`GatewayConfig::from_file`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
    /// Configuration format version string (not validated by this module).
    pub version: String,
    /// Header names, clock tolerance, trusted roots and debug flag.
    pub settings: GatewaySettings,
    /// Tool definitions keyed by tool name; each route refers to one of these by name.
    pub tools: HashMap<String, ToolConfig>,
    /// Route table; [`GatewayConfig::match_route`] checks entries in this order.
    pub routes: Vec<RouteConfig>,
}
/// Gateway-wide settings. Every field has a serde default, so `settings: {}` is valid.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewaySettings {
    /// Header name for the warrant; defaults to `X-Tenuo-Warrant`.
    #[serde(default = "default_warrant_header")]
    pub warrant_header: String,
    /// Header name for the proof-of-possession value; defaults to `X-Tenuo-PoP`.
    #[serde(default = "default_pop_header")]
    pub pop_header: String,
    /// Header name for approvals; defaults to `X-Tenuo-Approvals`.
    #[serde(default = "default_approval_header")]
    pub approval_header: String,
    /// Allowed clock skew in seconds; defaults to `crate::planes::DEFAULT_CLOCK_TOLERANCE_SECS`.
    #[serde(default = "default_clock_tolerance")]
    pub clock_tolerance_secs: u64,
    /// Trusted root entries as strings; empty by default.
    /// NOTE(review): the sample config uses 64-hex-character values, presumably encoded
    /// public keys — confirm the expected encoding.
    #[serde(default)]
    pub trusted_roots: Vec<String>,
    /// Debug flag; `false` by default. Not read anywhere in this module.
    #[serde(default)]
    pub debug_mode: bool,
}
/// Serde default for [`GatewaySettings::warrant_header`].
fn default_warrant_header() -> String {
    String::from("X-Tenuo-Warrant")
}

/// Serde default for [`GatewaySettings::pop_header`].
fn default_pop_header() -> String {
    String::from("X-Tenuo-PoP")
}

/// Serde default for [`GatewaySettings::approval_header`].
fn default_approval_header() -> String {
    String::from("X-Tenuo-Approvals")
}
/// Serde default for [`GatewaySettings::clock_tolerance_secs`]: the crate-wide
/// `planes::DEFAULT_CLOCK_TOLERANCE_SECS`, converted to `u64` with an `as` cast.
fn default_clock_tolerance() -> u64 {
    let tolerance = crate::planes::DEFAULT_CLOCK_TOLERANCE_SECS;
    tolerance as u64
}
impl Default for GatewaySettings {
    /// Produces the same values that deserializing an empty `settings:` mapping yields:
    /// the `default_*` helpers for headers and tolerance, no trusted roots, debug off.
    fn default() -> Self {
        let warrant_header = default_warrant_header();
        let pop_header = default_pop_header();
        let approval_header = default_approval_header();
        let clock_tolerance_secs = default_clock_tolerance();
        Self {
            warrant_header,
            pop_header,
            approval_header,
            clock_tolerance_secs,
            trusted_roots: Vec::default(),
            debug_mode: bool::default(),
        }
    }
}
/// A tool that routes map onto, with its default constraint extraction rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolConfig {
    /// Human-readable description of the tool.
    pub description: String,
    /// Extraction rules keyed by constraint name. A route-level rule with the same name
    /// replaces the entry here when the two sets are merged.
    pub constraints: HashMap<String, ExtractionRule>,
}
/// A single route: a path pattern plus optional method filter, mapped to a tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteConfig {
    /// `/`-separated path pattern; a segment written `{name}` captures a path parameter.
    pub pattern: String,
    /// Accepted HTTP methods, compared case-insensitively; an empty list accepts any method.
    #[serde(default)]
    pub method: Vec<String>,
    /// Name of the entry in [`GatewayConfig::tools`] this route maps to.
    pub tool: String,
    /// Route-specific extraction rules; these override tool-level rules of the same name.
    #[serde(default)]
    pub constraints: HashMap<String, ExtractionRule>,
}
/// Output of constraint extraction for a matched route.
#[derive(Debug)]
pub struct ExtractionResult {
    /// Extracted constraint values keyed by constraint name.
    pub constraints: HashMap<String, ConstraintValue>,
    /// Traces returned by the extractor alongside the values.
    pub traces: Vec<ExtractionTrace>,
    /// Name of the tool the matched route maps to.
    pub tool: String,
    /// Left as `None` by both `extract_constraints` implementations in this module;
    /// presumably populated later by the caller from request headers — confirm.
    pub warrant_base64: Option<String>,
    /// Left as `None` by the extractors in this module (see `warrant_base64`).
    pub signature_base64: Option<String>,
    /// Left empty by the extractors in this module (see `warrant_base64`).
    pub approvals_base64: Vec<String>,
}
impl GatewayConfig {
pub fn from_yaml(yaml: &str) -> Result<Self, ConfigError> {
serde_yaml::from_str(yaml).map_err(ConfigError::YamlParse)
}
#[cfg(not(target_arch = "wasm32"))]
pub fn from_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
let content = std::fs::read_to_string(path.as_ref())
.map_err(|e| ConfigError::FileRead(path.as_ref().display().to_string(), e))?;
Self::from_yaml(&content)
}
pub fn match_route(
&self,
method: &str,
path: &str,
) -> Option<(&RouteConfig, HashMap<String, String>)> {
for route in &self.routes {
if !route.method.is_empty()
&& !route.method.iter().any(|m| m.eq_ignore_ascii_case(method))
{
continue;
}
if let Some(params) = match_pattern(&route.pattern, path) {
return Some((route, params));
}
}
None
}
pub fn extract_constraints(
&self,
route: &RouteConfig,
ctx: &RequestContext,
) -> Result<ExtractionResult, ExtractionError> {
let tool_config = self.tools.get(&route.tool).ok_or_else(|| ExtractionError {
field: route.tool.clone(),
source: ExtractionSource::Literal,
path: String::new(),
hint: format!("Tool '{}' not defined in configuration", route.tool),
required: true,
})?;
let mut all_rules = tool_config.constraints.clone();
for (name, rule) in &route.constraints {
all_rules.insert(name.clone(), rule.clone());
}
let (constraints, traces) = extract_all(&all_rules, ctx)?;
Ok(ExtractionResult {
constraints,
traces,
tool: route.tool.clone(),
warrant_base64: None, signature_base64: None,
approvals_base64: Vec::new(),
})
}
pub fn validate(&self) -> Result<(), Vec<ConfigValidationError>> {
let mut errors = Vec::new();
for (i, route) in self.routes.iter().enumerate() {
if !self.tools.contains_key(&route.tool) {
errors.push(ConfigValidationError {
location: format!("routes[{}]", i),
message: format!("Tool '{}' is not defined", route.tool),
});
}
if let Err(msg) = validate_pattern(&route.pattern) {
errors.push(ConfigValidationError {
location: format!("routes[{}].pattern", i),
message: msg,
});
}
}
for (tool_name, tool_config) in &self.tools {
for (field_name, rule) in &tool_config.constraints {
if rule.from == ExtractionSource::Body && rule.path.is_empty() {
errors.push(ConfigValidationError {
location: format!("tools.{}.constraints.{}", tool_name, field_name),
message: "Body extraction requires a path".into(),
});
}
}
}
if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
}
}
/// Matches a request `path` against a route `pattern`, segment by segment on `/`.
///
/// A pattern segment written `{name}` captures the corresponding path segment under `name`.
/// Every other segment must match literally. A single trailing slash on either side is
/// tolerated, so `/a/{id}` matches `/a/1/` and `/a/{id}/` matches `/a/1`.
///
/// A parameter never captures an empty segment: `/a/{id}` does not match `/a/` or `/a//`.
/// Without this rule a request would yield an empty constraint value.
///
/// Returns the captured parameters, or `None` if `path` does not match.
fn match_pattern(pattern: &str, path: &str) -> Option<HashMap<String, String>> {
    let pattern_parts: Vec<&str> = pattern.split('/').collect();
    let path_parts: Vec<&str> = path.split('/').collect();

    if pattern_parts.len() != path_parts.len() {
        // The only allowed length difference is one extra trailing empty segment (a trailing `/`).
        let pattern_has_trailing_slash = pattern_parts.len() == path_parts.len() + 1
            && pattern_parts.last() == Some(&"");
        let path_has_trailing_slash = path_parts.len() == pattern_parts.len() + 1
            && path_parts.last() == Some(&"");
        if !(pattern_has_trailing_slash || path_has_trailing_slash) {
            return None;
        }
    }

    let mut params = HashMap::new();
    // `zip` stops at the shorter side, which skips the tolerated trailing empty segment.
    for (pattern_part, path_part) in pattern_parts.iter().zip(path_parts.iter()) {
        match pattern_part
            .strip_prefix('{')
            .and_then(|rest| rest.strip_suffix('}'))
        {
            Some(name) => {
                if path_part.is_empty() {
                    return None;
                }
                params.insert(name.to_string(), path_part.to_string());
            }
            None if pattern_part != path_part => return None,
            None => {}
        }
    }
    Some(params)
}
/// Checks that `{...}` parameter groups in `pattern` are well-formed.
///
/// Errors, reported for the first problem found scanning left to right:
/// - `"Nested braces not allowed"` for a `{` inside an open group;
/// - `"Unmatched closing brace"` for a `}` with no open group;
/// - `"Empty parameter name"` for `{}`;
/// - `"Unclosed brace"` if the pattern ends inside a group.
fn validate_pattern(pattern: &str) -> Result<(), String> {
    // `Some(n)` while inside a `{...}` group that has collected `n` characters so far.
    let mut open_group: Option<usize> = None;
    for ch in pattern.chars() {
        open_group = match (ch, open_group) {
            ('{', Some(_)) => return Err("Nested braces not allowed".into()),
            ('{', None) => Some(0),
            ('}', None) => return Err("Unmatched closing brace".into()),
            ('}', Some(0)) => return Err("Empty parameter name".into()),
            ('}', Some(_)) => None,
            (_, Some(len)) => Some(len + 1),
            (_, None) => None,
        };
    }
    match open_group {
        Some(_) => Err("Unclosed brace".into()),
        None => Ok(()),
    }
}
/// Errors returned while loading a [`GatewayConfig`].
#[derive(Debug)]
pub enum ConfigError {
    /// The YAML text could not be deserialized into a `GatewayConfig`.
    YamlParse(serde_yaml::Error),
    /// Reading the file failed; holds the displayed path and the underlying I/O error.
    FileRead(String, std::io::Error),
}

impl std::fmt::Display for ConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::FileRead(file, io_err) => write!(f, "Failed to read {}: {}", file, io_err),
            Self::YamlParse(yaml_err) => write!(f, "YAML parse error: {}", yaml_err),
        }
    }
}

/// Uses the default `source()`; the inner error's text is already part of the `Display` output.
impl std::error::Error for ConfigError {}
/// A single problem reported by [`GatewayConfig::validate`].
#[derive(Debug)]
pub struct ConfigValidationError {
    /// Where the problem was found, e.g. `routes[0].pattern`.
    pub location: String,
    /// Description of the problem.
    pub message: String,
}

impl std::fmt::Display for ConfigValidationError {
    /// Renders as `<location>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { location, message } = self;
        write!(f, "{}: {}", location, message)
    }
}
/// Set of HTTP methods accepted by a compiled route.
///
/// The seven standard methods are stored as bits in `standard`. Any other method name is
/// kept upper-cased in `custom`. When `all` is set, every method matches.
#[derive(Debug, Clone)]
pub struct MethodMask {
    standard: u8,
    custom: std::collections::HashSet<String>,
    all: bool,
}

impl MethodMask {
    // One bit per standard method in `standard`.
    const GET: u8 = 1 << 0;
    const POST: u8 = 1 << 1;
    const PUT: u8 = 1 << 2;
    const DELETE: u8 = 1 << 3;
    const PATCH: u8 = 1 << 4;
    const HEAD: u8 = 1 << 5;
    const OPTIONS: u8 = 1 << 6;

    /// Returns a mask that accepts every method.
    pub fn all() -> Self {
        Self {
            standard: 0,
            custom: std::collections::HashSet::new(),
            all: true,
        }
    }

    /// Builds a mask from method names, matched case-insensitively.
    ///
    /// An empty slice yields [`MethodMask::all`], consistent with an empty `method` list on a
    /// route accepting any method.
    pub fn from_methods(methods: &[String]) -> Self {
        if methods.is_empty() {
            return Self::all();
        }
        let mut standard_mask = 0u8;
        let mut custom_set = std::collections::HashSet::new();
        for method in methods {
            let method_upper = method.to_uppercase();
            let bit = Self::method_bit(&method_upper);
            if bit != 0 {
                standard_mask |= bit;
            } else {
                custom_set.insert(method_upper);
            }
        }
        Self {
            standard: standard_mask,
            custom: custom_set,
            all: false,
        }
    }

    /// Returns the bit for a standard method name given in upper case, or `0` otherwise.
    fn method_bit(method: &str) -> u8 {
        match method {
            "GET" => Self::GET,
            "POST" => Self::POST,
            "PUT" => Self::PUT,
            "DELETE" => Self::DELETE,
            "PATCH" => Self::PATCH,
            "HEAD" => Self::HEAD,
            "OPTIONS" => Self::OPTIONS,
            _ => 0,
        }
    }

    /// Returns `true` if `method` is accepted, comparing case-insensitively.
    #[inline]
    pub fn matches(&self, method: &str) -> bool {
        if self.all {
            return true;
        }
        // Fast path: this runs per request, and request methods are normally already upper
        // case. A hit here is exactly a canonical standard name, whose upper-cased form is
        // itself, so the result equals the slow path's result without allocating.
        let raw_bit = Self::method_bit(method);
        if raw_bit != 0 {
            return (self.standard & raw_bit) != 0;
        }
        let method_upper = method.to_uppercase();
        let bit = Self::method_bit(&method_upper);
        if bit != 0 {
            (self.standard & bit) != 0
        } else {
            self.custom.contains(&method_upper)
        }
    }
}

impl Default for MethodMask {
    /// Accepts every method (same as [`MethodMask::all`]).
    fn default() -> Self {
        Self::all()
    }
}
/// A route prepared by [`CompiledGatewayConfig::compile`] for request-time matching.
#[derive(Debug, Clone)]
pub struct CompiledRoute {
    /// The original route configuration.
    pub config: RouteConfig,
    /// Method filter built from `config.method`.
    pub method_mask: MethodMask,
    /// Tool-level and route-level rules merged (route-level wins) and compiled once.
    pub extraction_rules: CompiledExtractionRules,
    /// Tool name, shared cheaply via `Arc`.
    pub tool: Arc<str>,
}
/// Result of [`CompiledGatewayConfig::match_route`].
#[derive(Debug)]
pub struct RouteMatch<'a> {
    /// The matched compiled route.
    pub route: &'a CompiledRoute,
    /// `{param}` values captured by the router from the request path.
    pub path_params: HashMap<String, String>,
}
/// Router-backed form of [`GatewayConfig`], built with [`CompiledGatewayConfig::compile`].
pub struct CompiledGatewayConfig {
    /// Settings copied unchanged from the source configuration.
    pub settings: GatewaySettings,
    // Maps route patterns to indices into `routes`.
    router: matchit::Router<usize>,
    // Compiled routes in configuration order; `routes[i]` is configuration route `i`.
    routes: Vec<CompiledRoute>,
}
/// Errors raised by [`CompiledGatewayConfig::compile`].
#[derive(Debug)]
pub enum CompileError {
    /// Invalid route pattern. Not constructed by `compile` in this module.
    InvalidPattern {
        route_index: usize,
        pattern: String,
        error: String,
    },
    /// A route references a tool that is not defined in the configuration.
    UndefinedTool { route_index: usize, tool: String },
    /// The router refused to register a route's pattern.
    RouterConflict {
        route_index: usize,
        pattern: String,
        error: String,
    },
}

impl std::fmt::Display for CompileError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant starts with the offending route's position in `routes`.
        let index = match self {
            Self::InvalidPattern { route_index, .. }
            | Self::UndefinedTool { route_index, .. }
            | Self::RouterConflict { route_index, .. } => route_index,
        };
        write!(f, "Route {}: ", index)?;
        match self {
            Self::InvalidPattern { pattern, error, .. } => {
                write!(f, "Invalid pattern '{}': {}", pattern, error)
            }
            Self::UndefinedTool { tool, .. } => write!(f, "Tool '{}' is not defined", tool),
            Self::RouterConflict { pattern, error, .. } => {
                write!(f, "Pattern '{}' conflicts: {}", pattern, error)
            }
        }
    }
}

impl std::error::Error for CompileError {}
impl CompiledGatewayConfig {
pub fn compile(config: GatewayConfig) -> Result<Self, CompileError> {
let mut router = matchit::Router::new();
let mut routes = Vec::with_capacity(config.routes.len());
for (i, route) in config.routes.into_iter().enumerate() {
let tool_config =
config
.tools
.get(&route.tool)
.ok_or_else(|| CompileError::UndefinedTool {
route_index: i,
tool: route.tool.clone(),
})?;
router
.insert(&route.pattern, i)
.map_err(|e| CompileError::RouterConflict {
route_index: i,
pattern: route.pattern.clone(),
error: e.to_string(),
})?;
let mut all_rules = tool_config.constraints.clone();
for (name, rule) in &route.constraints {
all_rules.insert(name.clone(), rule.clone());
}
let extraction_rules = CompiledExtractionRules::compile(all_rules);
let compiled = CompiledRoute {
method_mask: MethodMask::from_methods(&route.method),
tool: Arc::from(route.tool.as_str()),
extraction_rules,
config: route,
};
routes.push(compiled);
}
Ok(Self {
settings: config.settings,
router,
routes,
})
}
pub fn match_route(&self, method: &str, path: &str) -> Option<RouteMatch<'_>> {
let clean_path = path.split('?').next().unwrap_or(path);
let matched = self.router.at(clean_path).ok()?;
let route_idx = *matched.value;
let route = &self.routes[route_idx];
if !route.method_mask.matches(method) {
return None;
}
let mut path_params = HashMap::new();
for (key, value) in matched.params.iter() {
path_params.insert(key.to_string(), value.to_string());
}
Some(RouteMatch { route, path_params })
}
pub fn extract_constraints(
&self,
route_match: &RouteMatch<'_>,
ctx: &RequestContext,
) -> Result<ExtractionResult, ExtractionError> {
let mut full_ctx = ctx.clone();
full_ctx.path_params = route_match.path_params.clone();
let (constraints, traces) = route_match.route.extraction_rules.extract_all(&full_ctx)?;
Ok(ExtractionResult {
constraints,
traces,
tool: route_match.route.tool.to_string(),
warrant_base64: None, signature_base64: None,
approvals_base64: Vec::new(),
})
}
}
// Unit tests for YAML parsing, route matching (plain and compiled), method masks,
// pattern validation and constraint extraction.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// Shared fixture: one tool (`manage_infrastructure`) whose constraints come from path
// params and the JSON body, plus one POST/PUT route mapping onto it.
// NOTE: the raw string below is YAML; its whitespace is significant.
const SAMPLE_CONFIG: &str = r#"
version: "1"
settings:
warrant_header: "X-Tenuo-Warrant"
pop_header: "X-Tenuo-PoP"
clock_tolerance_secs: 30
trusted_roots:
- "f32e74b5b8569dc288db0109b7ec0d8eb3b4e5be7b07c647171d53fd31e7391f"
tools:
manage_infrastructure:
description: "Kubernetes cluster management"
constraints:
cluster:
from: path
path: "cluster"
required: true
action:
from: path
path: "action"
required: true
replicas:
from: body
path: "spec.replicas"
type: integer
cost:
from: body
path: "metadata.estimatedCost"
type: float
routes:
- pattern: "/api/v1/clusters/{cluster}/{action}"
method: ["POST", "PUT"]
tool: "manage_infrastructure"
"#;
// The sample YAML deserializes into the expected top-level structure.
#[test]
fn test_parse_config() {
let config = GatewayConfig::from_yaml(SAMPLE_CONFIG).unwrap();
assert_eq!(config.version, "1");
assert_eq!(config.settings.warrant_header, "X-Tenuo-Warrant");
assert!(config.tools.contains_key("manage_infrastructure"));
assert_eq!(config.routes.len(), 1);
}
// Uncompiled matcher: an allowed method plus a matching path yields the route and both
// captured path parameters.
#[test]
fn test_route_matching() {
let config = GatewayConfig::from_yaml(SAMPLE_CONFIG).unwrap();
let result = config.match_route("POST", "/api/v1/clusters/staging-web/scale");
assert!(result.is_some());
let (route, params) = result.unwrap();
assert_eq!(route.tool, "manage_infrastructure");
assert_eq!(params.get("cluster"), Some(&"staging-web".to_string()));
assert_eq!(params.get("action"), Some(&"scale".to_string()));
}
// Uncompiled matcher: a method outside the route's list (GET) does not match.
#[test]
fn test_route_method_mismatch() {
let config = GatewayConfig::from_yaml(SAMPLE_CONFIG).unwrap();
let result = config.match_route("GET", "/api/v1/clusters/staging-web/scale");
assert!(result.is_none());
}
// Uncompiled extraction: the caller copies the matched path params into the context
// before extracting; values come from both the path and the JSON body.
#[test]
fn test_constraint_extraction() {
let config = GatewayConfig::from_yaml(SAMPLE_CONFIG).unwrap();
let (route, path_params) = config
.match_route("POST", "/api/v1/clusters/staging-web/scale")
.unwrap();
let mut ctx = RequestContext::with_body(json!({
"spec": { "replicas": 5 },
"metadata": { "estimatedCost": 150.0 }
}));
ctx.path_params = path_params;
let result = config.extract_constraints(route, &ctx).unwrap();
assert_eq!(
result.constraints.get("cluster"),
Some(&ConstraintValue::String("staging-web".into()))
);
assert_eq!(
result.constraints.get("replicas"),
Some(&ConstraintValue::Integer(5))
);
assert_eq!(
result.constraints.get("cost"),
Some(&ConstraintValue::Float(150.0))
);
}
// `match_pattern`: parameter and literal segments match, while extra or differing
// segments do not.
#[test]
fn test_pattern_matching() {
assert!(match_pattern("/api/{id}", "/api/123").is_some());
assert!(match_pattern("/api/{a}/{b}", "/api/x/y").is_some());
assert!(match_pattern("/api/static", "/api/static").is_some());
assert!(match_pattern("/api/{id}", "/api/123/extra").is_none());
assert!(match_pattern("/api/{id}", "/different/123").is_none());
}
// `validate` reports both an undefined tool and an empty `{}` parameter in one pass.
#[test]
fn test_config_validation() {
let bad_config = r#"
version: "1"
settings: {}
tools: {}
routes:
- pattern: "/api/{}"
tool: "undefined_tool"
"#;
let config = GatewayConfig::from_yaml(bad_config).unwrap();
let errors = config.validate().unwrap_err();
assert!(errors.iter().any(|e| e.message.contains("undefined_tool")));
assert!(errors.iter().any(|e| e.message.contains("Empty parameter")));
}
// Compiled (matchit) matcher returns the route and captured path params.
#[test]
fn test_compiled_route_matching() {
let config = GatewayConfig::from_yaml(SAMPLE_CONFIG).unwrap();
let compiled = CompiledGatewayConfig::compile(config).unwrap();
let result = compiled.match_route("POST", "/api/v1/clusters/staging-web/scale");
assert!(result.is_some());
let route_match = result.unwrap();
assert_eq!(route_match.route.tool.as_ref(), "manage_infrastructure");
assert_eq!(
route_match.path_params.get("cluster"),
Some(&"staging-web".to_string())
);
assert_eq!(
route_match.path_params.get("action"),
Some(&"scale".to_string())
);
}
// MethodMask: listed methods match case-insensitively, others do not; `all()` matches any.
#[test]
fn test_compiled_method_mask() {
let mask = MethodMask::from_methods(&["POST".to_string(), "PUT".to_string()]);
assert!(mask.matches("POST"));
assert!(mask.matches("PUT"));
assert!(mask.matches("post")); assert!(!mask.matches("GET"));
assert!(!mask.matches("DELETE"));
let all_mask = MethodMask::all();
assert!(all_mask.matches("GET"));
assert!(all_mask.matches("POST"));
assert!(all_mask.matches("DELETE"));
}
// Non-standard methods are stored in the custom set and also match case-insensitively.
#[test]
fn test_custom_http_methods() {
let mask = MethodMask::from_methods(&[
"POST".to_string(),
"PURGE".to_string(),
"PROPFIND".to_string(),
]);
assert!(mask.matches("POST"));
assert!(!mask.matches("GET"));
assert!(mask.matches("PURGE"));
assert!(mask.matches("purge")); assert!(mask.matches("PROPFIND"));
assert!(mask.matches("propfind"));
assert!(!mask.matches("PATCH"));
assert!(!mask.matches("CUSTOM_METHOD"));
}
// `MethodMask::all()` accepts standard and custom methods alike.
#[test]
fn test_method_mask_all() {
let all_mask = MethodMask::all();
assert!(all_mask.matches("GET"));
assert!(all_mask.matches("POST"));
assert!(all_mask.matches("PUT"));
assert!(all_mask.matches("DELETE"));
assert!(all_mask.matches("PATCH"));
assert!(all_mask.matches("HEAD"));
assert!(all_mask.matches("OPTIONS"));
assert!(all_mask.matches("PURGE"));
assert!(all_mask.matches("PROPFIND"));
assert!(all_mask.matches("CUSTOM_METHOD"));
}
// Compiled extraction: unlike the uncompiled flow, the context carries no path params;
// `CompiledGatewayConfig::extract_constraints` fills them in from the route match.
#[test]
fn test_compiled_constraint_extraction() {
let config = GatewayConfig::from_yaml(SAMPLE_CONFIG).unwrap();
let compiled = CompiledGatewayConfig::compile(config).unwrap();
let route_match = compiled
.match_route("POST", "/api/v1/clusters/staging-web/scale")
.unwrap();
let ctx = RequestContext::with_body(json!({
"spec": { "replicas": 5 },
"metadata": { "estimatedCost": 150.0 }
}));
let result = compiled.extract_constraints(&route_match, &ctx).unwrap();
assert_eq!(
result.constraints.get("cluster"),
Some(&ConstraintValue::String("staging-web".into()))
);
assert_eq!(
result.constraints.get("replicas"),
Some(&ConstraintValue::Integer(5))
);
assert_eq!(
result.constraints.get("cost"),
Some(&ConstraintValue::Float(150.0))
);
}
// Compiled matcher rejects a method outside the route's list.
#[test]
fn test_compiled_method_mismatch() {
let config = GatewayConfig::from_yaml(SAMPLE_CONFIG).unwrap();
let compiled = CompiledGatewayConfig::compile(config).unwrap();
let result = compiled.match_route("GET", "/api/v1/clusters/staging-web/scale");
assert!(result.is_none());
}
// Compiled matcher ignores the query string when matching.
#[test]
fn test_compiled_path_with_query() {
let config = GatewayConfig::from_yaml(SAMPLE_CONFIG).unwrap();
let compiled = CompiledGatewayConfig::compile(config).unwrap();
let result = compiled.match_route("POST", "/api/v1/clusters/staging-web/scale?foo=bar");
assert!(result.is_some());
assert_eq!(
result.unwrap().path_params.get("cluster"),
Some(&"staging-web".to_string())
);
}
// Sanity check that the matchit router accepts the `{param}` syntax used in route patterns.
#[test]
fn test_matchit_directly() {
let mut router: matchit::Router<usize> = matchit::Router::new();
router.insert("/api/{cluster}/{action}", 0).unwrap();
let matched = router.at("/api/staging-web/scale");
assert!(matched.is_ok());
let m = matched.unwrap();
assert_eq!(*m.value, 0);
assert_eq!(m.params.get("cluster"), Some("staging-web"));
assert_eq!(m.params.get("action"), Some("scale"));
}
}