use crate::types::{RiskLevel, ValidationError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::collections::HashSet;
/// Resolves a server id from the process environment.
///
/// Variables are consulted in priority order: `PMCP_SERVER_ID` first, then
/// `AWS_LAMBDA_FUNCTION_NAME`. A variable that is unset, not valid Unicode,
/// or set to the empty string is skipped. This means an empty
/// `PMCP_SERVER_ID` no longer masks a Lambda function name.
///
/// Returns `None` when no variable yields a non-empty value.
pub fn resolve_server_id_from_env() -> Option<String> {
    ["PMCP_SERVER_ID", "AWS_LAMBDA_FUNCTION_NAME"]
        .iter()
        // `var` errors (unset / non-Unicode) are treated as "not provided".
        .filter_map(|name| std::env::var(name).ok())
        // Empty means "unset", so keep looking at lower-priority variables.
        .find(|value| !value.is_empty())
}
/// One declared operation, as listed in the `operations` array of a
/// code-mode config (e.g. `[[code_mode.operations]]` in TOML).
///
/// When deserializing, `id` and `category` are required. `description`
/// defaults to an empty string and `path` defaults to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationEntry {
/// Operation identifier returned by `OperationRegistry::lookup`.
pub id: String,
/// Operation category (the tests use "read", "write", "delete").
/// `OperationRegistry::from_entries` does not index an empty category.
pub category: String,
/// Free-form human-readable description; defaults to "".
#[serde(default)]
pub description: String,
/// Path used as the registry key; entries with `None` are not indexed
/// by `OperationRegistry::from_entries`.
#[serde(default)]
pub path: Option<String>,
}
#[derive(Debug, Clone, Default)]
pub struct OperationRegistry {
path_to_id: HashMap<String, String>,
path_to_category: HashMap<String, String>,
}
impl OperationRegistry {
pub fn from_entries(entries: &[OperationEntry]) -> Self {
let mut path_to_id = HashMap::with_capacity(entries.len());
let mut path_to_category = HashMap::with_capacity(entries.len());
for entry in entries {
if let Some(ref path) = entry.path {
path_to_id.insert(path.clone(), entry.id.clone());
if !entry.category.is_empty() {
path_to_category.insert(path.clone(), entry.category.clone());
}
}
}
Self {
path_to_id,
path_to_category,
}
}
pub fn lookup(&self, path: &str) -> Option<&str> {
self.path_to_id.get(path).map(|s| s.as_str())
}
pub fn lookup_category(&self, path: &str) -> Option<&str> {
self.path_to_category.get(path).map(|s| s.as_str())
}
pub fn is_empty(&self) -> bool {
self.path_to_id.is_empty()
}
}
/// Code-mode configuration.
///
/// Deserialized either directly from a TOML document or from its
/// `[code_mode]` table via `CodeModeConfig::from_toml`. Every field carries a
/// serde default, so an empty table is valid. Fields prefixed `sql_` also
/// accept their unprefixed TOML key as an alias (e.g. `allow_writes`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeModeConfig {
/// Whether code mode is enabled. Defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// GraphQL: exported as both `allow_write` and `allow_delete` by
/// `to_server_config_entity`.
#[serde(default)]
pub allow_mutations: bool,
/// GraphQL: exported as `allowed_operations`.
#[serde(default)]
pub allowed_mutations: HashSet<String>,
/// GraphQL: exported as `blocked_operations`.
#[serde(default)]
pub blocked_mutations: HashSet<String>,
/// GraphQL: exported as `allow_admin`.
#[serde(default)]
pub allow_introspection: bool,
/// GraphQL: exported as `blocked_fields`.
#[serde(default)]
pub blocked_fields: HashSet<String>,
// NOTE(review): `allowed_queries` / `blocked_queries` are not read by any
// entity conversion in this file; confirm where they are enforced.
#[serde(default)]
pub allowed_queries: HashSet<String>,
#[serde(default)]
pub blocked_queries: HashSet<String>,
/// OpenAPI: defaults to `true`; exported as `auto_approve_read_only`.
#[serde(default = "default_true")]
pub openapi_reads_enabled: bool,
/// OpenAPI: exported as `allow_write`; when `false` the write mode is
/// "deny_all".
#[serde(default)]
pub openapi_allow_writes: bool,
/// OpenAPI: non-empty selects the "allowlist" write mode; merged with
/// `openapi_allowed_deletes` into `allowed_operations`.
#[serde(default)]
pub openapi_allowed_writes: HashSet<String>,
/// OpenAPI: exported as `blocked_operations`; selects the "blocklist"
/// write mode when non-empty and no allowlist is set.
#[serde(default)]
pub openapi_blocked_writes: HashSet<String>,
/// OpenAPI: exported as `allow_delete`.
#[serde(default)]
pub openapi_allow_deletes: bool,
/// OpenAPI: merged into `allowed_operations`.
#[serde(default)]
pub openapi_allowed_deletes: HashSet<String>,
/// OpenAPI: exported as both `blocked_path_patterns` and
/// `sensitive_path_patterns`.
#[serde(default)]
pub openapi_blocked_paths: HashSet<String>,
/// OpenAPI: exported as `internal_blocked_fields`.
#[serde(default)]
pub openapi_internal_blocked_fields: HashSet<String>,
/// OpenAPI: exported as `output_blocked_fields`.
#[serde(default)]
pub openapi_output_blocked_fields: HashSet<String>,
/// OpenAPI: exported as `require_output_declaration`.
#[serde(default)]
pub openapi_require_output_declaration: bool,
/// SQL: defaults to `true`. TOML alias `reads_enabled`.
// NOTE(review): not read by `to_sql_server_entity`; confirm enforcement.
#[serde(default = "default_true", alias = "reads_enabled")]
pub sql_reads_enabled: bool,
/// SQL: exported as `allow_write`. TOML alias `allow_writes`.
#[serde(default, alias = "allow_writes")]
pub sql_allow_writes: bool,
/// SQL: exported as `allow_delete`. TOML alias `allow_deletes`.
#[serde(default, alias = "allow_deletes")]
pub sql_allow_deletes: bool,
/// SQL: exported as `allow_admin`. TOML alias `allow_ddl`.
#[serde(default, alias = "allow_ddl")]
pub sql_allow_ddl: bool,
/// SQL: exported as `allowed_operations`. TOML alias `allowed_statements`.
#[serde(default, alias = "allowed_statements")]
pub sql_allowed_statements: HashSet<String>,
/// SQL: exported as `blocked_operations`. TOML alias `blocked_statements`.
#[serde(default, alias = "blocked_statements")]
pub sql_blocked_statements: HashSet<String>,
/// SQL: exported as `blocked_tables`. TOML alias `blocked_tables`.
#[serde(default, alias = "blocked_tables")]
pub sql_blocked_tables: HashSet<String>,
/// SQL: exported as `allowed_tables`. TOML alias `allowed_tables`.
#[serde(default, alias = "allowed_tables")]
pub sql_allowed_tables: HashSet<String>,
/// SQL: exported as `blocked_columns`. TOML alias `blocked_columns`.
#[serde(default, alias = "blocked_columns")]
pub sql_blocked_columns: HashSet<String>,
/// SQL: exported as `max_rows`; default 10 000. TOML alias `max_rows`.
#[serde(default = "default_sql_max_rows", alias = "max_rows")]
pub sql_max_rows: u64,
/// SQL: exported as `max_joins`; default 5. TOML alias `max_joins`.
#[serde(default = "default_sql_max_joins", alias = "max_joins")]
pub sql_max_joins: u32,
/// SQL: defaults to `true`. TOML alias `require_where_on_writes`.
// NOTE(review): not read by `to_sql_server_entity`; confirm enforcement.
#[serde(default = "default_true", alias = "require_where_on_writes")]
pub sql_require_where_on_writes: bool,
// NOTE(review): `action_tags` is not read anywhere in this file.
#[serde(default)]
pub action_tags: HashMap<String, String>,
/// Default 10. Exported as GraphQL/OpenAPI `max_depth` and as OpenAPI
/// `max_nesting_depth`.
#[serde(default = "default_max_depth")]
pub max_depth: u32,
/// Default 100. Exported as GraphQL `max_field_count`.
#[serde(default = "default_max_field_count")]
pub max_field_count: u32,
/// Default 1000. Exported as GraphQL/OpenAPI `max_cost`.
#[serde(default = "default_max_cost")]
pub max_cost: u32,
/// Exported as GraphQL `allowed_sensitive_categories`.
#[serde(default)]
pub allowed_sensitive_categories: HashSet<String>,
/// Default 300 (`default_token_ttl`).
#[serde(default = "default_token_ttl")]
pub token_ttl_seconds: i64,
/// Risk levels accepted by `should_auto_approve`; defaults to `[Low]`.
#[serde(default = "default_auto_approve_levels")]
pub auto_approve_levels: Vec<RiskLevel>,
/// Default 10 000. Exported as OpenAPI `max_script_length`.
#[serde(default = "default_max_query_length")]
pub max_query_length: usize,
/// Default 10 000.
// NOTE(review): not read anywhere in this file.
#[serde(default = "default_max_result_rows")]
pub max_result_rows: usize,
/// Default 30. Exported as OpenAPI `execution_timeout_seconds`.
#[serde(default = "default_query_timeout")]
pub query_timeout_seconds: u32,
/// Explicit server id. When `None`, `resolve_server_id` may fill it from
/// the environment, and `server_id()` falls back to "unknown".
#[serde(default)]
pub server_id: Option<String>,
/// A non-empty set makes `is_sdk_mode` return `true`.
#[serde(default)]
pub sdk_operations: HashSet<String>,
/// Declared operations (`[[operations]]` / `[[code_mode.operations]]`).
// Presumably fed to `OperationRegistry::from_entries` by callers — confirm.
#[serde(default)]
pub operations: Vec<OperationEntry>,
}
impl Default for CodeModeConfig {
fn default() -> Self {
Self {
enabled: false,
allow_mutations: false,
allowed_mutations: HashSet::new(),
blocked_mutations: HashSet::new(),
allow_introspection: false,
blocked_fields: HashSet::new(),
allowed_queries: HashSet::new(),
blocked_queries: HashSet::new(),
openapi_reads_enabled: true,
openapi_allow_writes: false,
openapi_allowed_writes: HashSet::new(),
openapi_blocked_writes: HashSet::new(),
openapi_allow_deletes: false,
openapi_allowed_deletes: HashSet::new(),
openapi_blocked_paths: HashSet::new(),
openapi_internal_blocked_fields: HashSet::new(),
openapi_output_blocked_fields: HashSet::new(),
openapi_require_output_declaration: false,
sql_reads_enabled: true,
sql_allow_writes: false,
sql_allow_deletes: false,
sql_allow_ddl: false,
sql_allowed_statements: HashSet::new(),
sql_blocked_statements: HashSet::new(),
sql_blocked_tables: HashSet::new(),
sql_allowed_tables: HashSet::new(),
sql_blocked_columns: HashSet::new(),
sql_max_rows: default_sql_max_rows(),
sql_max_joins: default_sql_max_joins(),
sql_require_where_on_writes: true,
action_tags: HashMap::new(),
max_depth: default_max_depth(),
max_field_count: default_max_field_count(),
max_cost: default_max_cost(),
allowed_sensitive_categories: HashSet::new(),
token_ttl_seconds: default_token_ttl(),
auto_approve_levels: default_auto_approve_levels(),
max_query_length: default_max_query_length(),
max_result_rows: default_max_result_rows(),
query_timeout_seconds: default_query_timeout(),
server_id: None,
sdk_operations: HashSet::new(),
operations: Vec::new(),
}
}
}
/// Shape of a full TOML config document as seen by `CodeModeConfig::from_toml`.
///
/// Only the `code_mode` table is captured. No `deny_unknown_fields` is set, so
/// other top-level tables (e.g. `[server]`, `[[tools]]`) are ignored. A missing
/// `code_mode` table yields `CodeModeConfig::default()`.
#[derive(Deserialize)]
struct TomlWrapper {
#[serde(default)]
code_mode: CodeModeConfig,
}
impl CodeModeConfig {
    /// Parses a full TOML config document and returns its `[code_mode]` table.
    ///
    /// Other top-level tables and keys are ignored. When the document has no
    /// `[code_mode]` table, [`CodeModeConfig::default`] is returned.
    ///
    /// # Errors
    ///
    /// Returns the `toml` deserialization error when the input is not valid
    /// TOML or a `code_mode` value does not match the expected type.
    pub fn from_toml(toml_str: &str) -> Result<Self, toml::de::Error> {
        let wrapper: TomlWrapper = toml::from_str(toml_str)?;
        Ok(wrapper.code_mode)
    }

    /// Returns the default configuration with `enabled` set to `true`.
    pub fn enabled() -> Self {
        Self {
            enabled: true,
            ..Default::default()
        }
    }

    /// Returns `true` when at least one SDK operation is configured.
    pub fn is_sdk_mode(&self) -> bool {
        !self.sdk_operations.is_empty()
    }

    /// Returns `true` when `risk_level` is listed in `auto_approve_levels`.
    pub fn should_auto_approve(&self, risk_level: RiskLevel) -> bool {
        self.auto_approve_levels.contains(&risk_level)
    }

    /// Returns the configured server id, or `"unknown"` when it is unset.
    ///
    /// Use [`Self::require_server_id`] where a missing id must be an error.
    pub fn server_id(&self) -> &str {
        self.server_id.as_deref().unwrap_or("unknown")
    }

    /// Fills `server_id` from the environment when it is not already set.
    ///
    /// An explicitly configured id always wins. Otherwise the value comes
    /// from `resolve_server_id_from_env`, and may remain `None`.
    pub fn resolve_server_id(&mut self) {
        if self.server_id.is_none() {
            self.server_id = resolve_server_id_from_env();
        }
    }

    /// Returns the server id, failing loudly instead of defaulting.
    ///
    /// # Errors
    ///
    /// Returns [`ValidationError::ConfigError`] when `server_id` is `None`.
    pub fn require_server_id(&self) -> Result<&str, ValidationError> {
        self.server_id.as_deref().ok_or_else(|| {
            ValidationError::ConfigError(
                "server_id is not set. Set it in config.toml, PMCP_SERVER_ID env var, \
or AWS_LAMBDA_FUNCTION_NAME (Lambda). Without it, AVP authorization \
will default-deny silently."
                    .into(),
            )
        })
    }

    /// Builds the GraphQL policy entity from this configuration.
    ///
    /// GraphQL has a single mutation switch, so `allow_mutations` drives both
    /// `allow_write` and `allow_delete`. Introspection maps to `allow_admin`.
    pub fn to_server_config_entity(&self) -> crate::policy::ServerConfigEntity {
        crate::policy::ServerConfigEntity {
            server_id: self.server_id().to_string(),
            server_type: "graphql".to_string(),
            allow_write: self.allow_mutations,
            allow_delete: self.allow_mutations,
            allow_admin: self.allow_introspection,
            allowed_operations: self.allowed_mutations.clone(),
            blocked_operations: self.blocked_mutations.clone(),
            max_depth: self.max_depth,
            max_field_count: self.max_field_count,
            max_cost: self.max_cost,
            max_api_calls: 50,
            blocked_fields: self.blocked_fields.clone(),
            allowed_sensitive_categories: self.allowed_sensitive_categories.clone(),
        }
    }

    /// Builds the OpenAPI policy entity from this configuration.
    ///
    /// `write_mode` is `"deny_all"` unless writes are allowed. When they are,
    /// a non-empty allowlist takes precedence over a non-empty blocklist,
    /// and `"allow_all"` applies when neither list is set. Allowed writes and
    /// allowed deletes are merged into `allowed_operations`.
    #[cfg(feature = "openapi-code-mode")]
    pub fn to_openapi_server_entity(&self) -> crate::policy::OpenAPIServerEntity {
        let mut allowed_operations = self.openapi_allowed_writes.clone();
        allowed_operations.extend(self.openapi_allowed_deletes.iter().cloned());
        let write_mode = if !self.openapi_allow_writes {
            "deny_all"
        } else if !self.openapi_allowed_writes.is_empty() {
            "allowlist"
        } else if !self.openapi_blocked_writes.is_empty() {
            "blocklist"
        } else {
            "allow_all"
        };
        // `max_query_length` is a usize; a plain `as u32` would wrap values
        // above u32::MAX into a small (possibly zero) limit. Saturate instead.
        let max_script_length = self.max_query_length.min(u32::MAX as usize) as u32;
        crate::policy::OpenAPIServerEntity {
            server_id: self.server_id().to_string(),
            server_type: "openapi".to_string(),
            allow_write: self.openapi_allow_writes,
            allow_delete: self.openapi_allow_deletes,
            allow_admin: false,
            write_mode: write_mode.to_string(),
            max_depth: self.max_depth,
            max_cost: self.max_cost,
            max_api_calls: 50,
            max_loop_iterations: 100,
            max_script_length,
            max_nesting_depth: self.max_depth,
            execution_timeout_seconds: self.query_timeout_seconds,
            allowed_operations,
            blocked_operations: self.openapi_blocked_writes.clone(),
            allowed_methods: HashSet::new(),
            blocked_methods: HashSet::new(),
            allowed_path_patterns: HashSet::new(),
            blocked_path_patterns: self.openapi_blocked_paths.clone(),
            sensitive_path_patterns: self.openapi_blocked_paths.clone(),
            auto_approve_read_only: self.openapi_reads_enabled,
            max_api_calls_for_auto_approve: 10,
            internal_blocked_fields: self.openapi_internal_blocked_fields.clone(),
            output_blocked_fields: self.openapi_output_blocked_fields.clone(),
            require_output_declaration: self.openapi_require_output_declaration,
        }
    }

    /// Builds the SQL policy entity from this configuration.
    ///
    /// DDL permission maps to `allow_admin`; statement allow/block lists map
    /// to `allowed_operations` / `blocked_operations`.
    #[cfg(feature = "sql-code-mode")]
    pub fn to_sql_server_entity(&self) -> crate::policy::SqlServerEntity {
        crate::policy::SqlServerEntity {
            server_id: self.server_id().to_string(),
            server_type: "sql".to_string(),
            allow_write: self.sql_allow_writes,
            allow_delete: self.sql_allow_deletes,
            allow_admin: self.sql_allow_ddl,
            max_rows: self.sql_max_rows,
            max_joins: self.sql_max_joins,
            allowed_operations: self.sql_allowed_statements.clone(),
            blocked_operations: self.sql_blocked_statements.clone(),
            blocked_tables: self.sql_blocked_tables.clone(),
            blocked_columns: self.sql_blocked_columns.clone(),
            allowed_tables: self.sql_allowed_tables.clone(),
        }
    }
}
/// Serde default for flags that are on unless explicitly disabled.
fn default_true() -> bool {
    true
}

/// Serde default for `token_ttl_seconds`: five minutes.
fn default_token_ttl() -> i64 {
    5 * 60
}
/// Serde default for `auto_approve_levels`: only low-risk code is auto-approved.
fn default_auto_approve_levels() -> Vec<RiskLevel> {
    std::iter::once(RiskLevel::Low).collect()
}
/// Serde default for `max_query_length`.
fn default_max_query_length() -> usize {
    10_000
}

/// Serde default for `max_result_rows`.
fn default_max_result_rows() -> usize {
    10_000
}

/// Serde default for `query_timeout_seconds`.
fn default_query_timeout() -> u32 {
    30
}

/// Serde default for `max_depth`.
fn default_max_depth() -> u32 {
    10
}

/// Serde default for `max_field_count`.
fn default_max_field_count() -> u32 {
    100
}

/// Serde default for `max_cost`.
fn default_max_cost() -> u32 {
    1_000
}

/// Serde default for `sql_max_rows`.
fn default_sql_max_rows() -> u64 {
    10 * 1_000
}

/// Serde default for `sql_max_joins`.
fn default_sql_max_joins() -> u32 {
    5
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Builds an [`OperationEntry`] that has a `path`, keeping the registry
    /// tests terse.
    fn entry(id: &str, category: &str, description: &str, path: &str) -> OperationEntry {
        OperationEntry {
            id: id.to_string(),
            category: category.to_string(),
            description: description.to_string(),
            path: Some(path.to_string()),
        }
    }

    #[test]
    fn test_default_config() {
        let config = CodeModeConfig::default();
        assert!(!config.enabled);
        assert!(!config.allow_mutations);
        assert_eq!(config.token_ttl_seconds, 300);
        assert_eq!(config.auto_approve_levels, vec![RiskLevel::Low]);
    }

    #[test]
    fn test_enabled_config() {
        let config = CodeModeConfig::enabled();
        assert!(config.enabled);
    }

    #[test]
    fn test_auto_approve() {
        let config = CodeModeConfig::default();
        assert!(config.should_auto_approve(RiskLevel::Low));
        assert!(!config.should_auto_approve(RiskLevel::Medium));
        assert!(!config.should_auto_approve(RiskLevel::High));
        assert!(!config.should_auto_approve(RiskLevel::Critical));
    }

    #[test]
    fn test_operation_registry_from_entries() {
        let entries = vec![
            entry("getCostAnomalies", "read", "Get cost anomalies", "/getCostAnomalies"),
            entry("listInstances", "read", "List EC2 instances", "/listInstances"),
        ];
        let registry = OperationRegistry::from_entries(&entries);
        assert_eq!(
            registry.lookup("/getCostAnomalies"),
            Some("getCostAnomalies")
        );
        assert_eq!(registry.lookup("/listInstances"), Some("listInstances"));
    }

    #[test]
    fn test_operation_registry_lookup_unregistered() {
        let entries = vec![entry("getCostAnomalies", "read", "", "/getCostAnomalies")];
        let registry = OperationRegistry::from_entries(&entries);
        assert_eq!(registry.lookup("/unknownPath"), None);
        assert_eq!(registry.lookup(""), None);
    }

    #[test]
    fn test_operation_registry_lookup_category() {
        let entries = vec![
            entry("getCostAnomalies", "read", "", "/getCostAnomalies"),
            entry("deleteReservation", "delete", "", "/deleteReservation"),
            entry("updateBudget", "write", "", "/updateBudget"),
        ];
        let registry = OperationRegistry::from_entries(&entries);
        assert_eq!(registry.lookup_category("/getCostAnomalies"), Some("read"));
        assert_eq!(
            registry.lookup_category("/deleteReservation"),
            Some("delete")
        );
        assert_eq!(registry.lookup_category("/updateBudget"), Some("write"));
        assert_eq!(registry.lookup_category("/unknownPath"), None);
    }

    #[test]
    fn test_operation_registry_empty_category_excluded() {
        let entries = vec![entry("legacyOp", "", "", "/legacyOp")];
        let registry = OperationRegistry::from_entries(&entries);
        assert_eq!(registry.lookup("/legacyOp"), Some("legacyOp"));
        assert_eq!(registry.lookup_category("/legacyOp"), None);
    }

    #[test]
    fn test_operation_registry_skips_entries_without_path() {
        let entries = vec![OperationEntry {
            id: "noPath".to_string(),
            category: "read".to_string(),
            description: String::new(),
            path: None,
        }];
        let registry = OperationRegistry::from_entries(&entries);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_operation_registry_is_empty() {
        let empty_registry = OperationRegistry::from_entries(&[]);
        assert!(empty_registry.is_empty());
        let entries = vec![entry("op1", "read", "", "/op1")];
        let registry = OperationRegistry::from_entries(&entries);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_operation_entry_deserialization() {
        let toml_str = r#"
            id = "getCostAnomalies"
            category = "read"
            description = "Get cost anomalies"
            path = "/getCostAnomalies"
        "#;
        let entry: OperationEntry =
            toml::from_str(toml_str).expect("Failed to deserialize OperationEntry");
        assert_eq!(entry.id, "getCostAnomalies");
        assert_eq!(entry.category, "read");
        assert_eq!(entry.description, "Get cost anomalies");
        assert_eq!(entry.path, Some("/getCostAnomalies".to_string()));
    }

    #[test]
    fn test_code_mode_config_with_operations() {
        let toml_str = r#"
            enabled = true
            [[operations]]
            id = "getCostAnomalies"
            category = "read"
            description = "Get cost anomalies"
            path = "/getCostAnomalies"
            [[operations]]
            id = "listInstances"
            category = "read"
            path = "/listInstances"
        "#;
        let config: CodeModeConfig = toml::from_str(toml_str).expect("Failed to deserialize");
        assert!(config.enabled);
        assert_eq!(config.operations.len(), 2);
        assert_eq!(config.operations[0].id, "getCostAnomalies");
        assert_eq!(config.operations[1].id, "listInstances");
    }

    #[test]
    fn test_code_mode_config_without_operations_defaults_to_empty() {
        let toml_str = r#"
            enabled = true
        "#;
        let config: CodeModeConfig = toml::from_str(toml_str).expect("Failed to deserialize");
        assert!(config.enabled);
        assert!(config.operations.is_empty());
    }

    #[test]
    fn test_from_toml_extracts_code_mode_section() {
        let toml_str = r#"
            [server]
            name = "cost-coach"
            type = "openapi-api"
            [code_mode]
            enabled = true
            token_ttl_seconds = 600
            server_id = "cost-coach"
            [[code_mode.operations]]
            id = "getCostAndUsage"
            category = "read"
            description = "Historical cost and usage data"
            path = "/getCostAndUsage"
            [[code_mode.operations]]
            id = "getCostAnomalies"
            category = "read"
            description = "Cost anomalies detected by AWS"
            path = "/getCostAnomalies"
            [[tools]]
            name = "some_tool"
        "#;
        let config = CodeModeConfig::from_toml(toml_str).expect("Failed to parse");
        assert!(config.enabled);
        assert_eq!(config.token_ttl_seconds, 600);
        assert_eq!(config.server_id, Some("cost-coach".to_string()));
        assert_eq!(config.operations.len(), 2);
        assert_eq!(config.operations[0].id, "getCostAndUsage");
        assert_eq!(config.operations[1].id, "getCostAnomalies");
        assert_eq!(
            config.operations[0].path,
            Some("/getCostAndUsage".to_string())
        );
    }

    #[test]
    fn test_from_toml_missing_code_mode_returns_default() {
        let toml_str = r#"
            [server]
            name = "some-server"
        "#;
        let config = CodeModeConfig::from_toml(toml_str).expect("Failed to parse");
        assert!(!config.enabled);
        assert!(config.operations.is_empty());
        assert_eq!(config.token_ttl_seconds, 300);
    }

    /// Environment variables consulted by `resolve_server_id_from_env`.
    const SERVER_ID_ENV_VARS: [&str; 2] = ["PMCP_SERVER_ID", "AWS_LAMBDA_FUNCTION_NAME"];

    /// Serializes tests that mutate the process environment. The test harness
    /// runs tests on several threads, and env vars are process-global.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` and clears the server-id env vars for one test.
    ///
    /// On drop it restores whatever values were present beforehand, so a
    /// developer or CI environment that sets these vars is not clobbered.
    /// Restoration happens in `drop` before `_lock` (a field) is released.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn acquire() -> Self {
            // A panicking env test poisons the lock; the `()` payload has no
            // invariant to protect, so recover and continue.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = SERVER_ID_ENV_VARS
                .iter()
                .map(|&name| (name, std::env::var_os(name)))
                .collect();
            for &name in SERVER_ID_ENV_VARS.iter() {
                std::env::remove_var(name);
            }
            Self { saved, _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (name, previous) in &self.saved {
                match previous {
                    Some(value) => std::env::set_var(name, value),
                    None => std::env::remove_var(name),
                }
            }
        }
    }

    #[test]
    fn resolve_server_id_from_explicit_config_takes_precedence() {
        let _g = EnvGuard::acquire();
        std::env::set_var("PMCP_SERVER_ID", "from-env");
        let mut config = CodeModeConfig {
            server_id: Some("from-config".to_string()),
            ..Default::default()
        };
        config.resolve_server_id();
        assert_eq!(config.server_id.as_deref(), Some("from-config"));
    }

    #[test]
    fn resolve_server_id_from_pmcp_env() {
        let _g = EnvGuard::acquire();
        std::env::set_var("PMCP_SERVER_ID", "my-server");
        let mut config = CodeModeConfig::default();
        config.resolve_server_id();
        assert_eq!(config.server_id.as_deref(), Some("my-server"));
    }

    #[test]
    fn resolve_server_id_from_lambda_env() {
        let _g = EnvGuard::acquire();
        std::env::set_var("AWS_LAMBDA_FUNCTION_NAME", "my-lambda-fn");
        let mut config = CodeModeConfig::default();
        config.resolve_server_id();
        assert_eq!(config.server_id.as_deref(), Some("my-lambda-fn"));
    }

    #[test]
    fn resolve_server_id_pmcp_wins_over_lambda() {
        let _g = EnvGuard::acquire();
        std::env::set_var("PMCP_SERVER_ID", "explicit");
        std::env::set_var("AWS_LAMBDA_FUNCTION_NAME", "lambda-fn");
        let mut config = CodeModeConfig::default();
        config.resolve_server_id();
        assert_eq!(config.server_id.as_deref(), Some("explicit"));
    }

    #[test]
    fn resolve_server_id_leaves_none_when_unset() {
        let _g = EnvGuard::acquire();
        let mut config = CodeModeConfig::default();
        config.resolve_server_id();
        assert!(config.server_id.is_none());
    }

    #[test]
    fn require_server_id_errors_when_unset() {
        let config = CodeModeConfig::default();
        let result = config.require_server_id();
        assert!(matches!(result, Err(ValidationError::ConfigError(_))));
    }

    #[test]
    fn require_server_id_returns_value_when_set() {
        let config = CodeModeConfig {
            server_id: Some("my-server".to_string()),
            ..Default::default()
        };
        assert_eq!(config.require_server_id().unwrap(), "my-server");
    }

    #[test]
    fn resolve_server_id_from_env_free_fn_treats_empty_as_unset() {
        let _g = EnvGuard::acquire();
        std::env::set_var("PMCP_SERVER_ID", "");
        assert_eq!(resolve_server_id_from_env(), None);
    }

    #[test]
    fn sql_config_accepts_unprefixed_toml_names() {
        let toml_str = r#"
            enabled = true
            allow_writes = true
            allow_deletes = true
            allow_ddl = true
            allowed_tables = ["users", "orders"]
            blocked_tables = ["secrets"]
            blocked_columns = ["password", "ssn"]
            max_rows = 5000
            max_joins = 3
            require_where_on_writes = false
        "#;
        let config: CodeModeConfig =
            toml::from_str(toml_str).expect("Failed to deserialize with unprefixed aliases");
        assert!(config.enabled);
        assert!(config.sql_allow_writes);
        assert!(config.sql_allow_deletes);
        assert!(config.sql_allow_ddl);
        assert!(config.sql_allowed_tables.contains("users"));
        assert!(config.sql_allowed_tables.contains("orders"));
        assert!(config.sql_blocked_tables.contains("secrets"));
        assert!(config.sql_blocked_columns.contains("password"));
        assert_eq!(config.sql_max_rows, 5000);
        assert_eq!(config.sql_max_joins, 3);
        assert!(!config.sql_require_where_on_writes);
    }

    #[test]
    fn sql_config_accepts_prefixed_toml_names() {
        let toml_str = r#"
            enabled = true
            sql_allow_writes = true
            sql_blocked_tables = ["secrets"]
            sql_max_rows = 5000
        "#;
        let config: CodeModeConfig =
            toml::from_str(toml_str).expect("Failed to deserialize with prefixed names");
        assert!(config.sql_allow_writes);
        assert!(config.sql_blocked_tables.contains("secrets"));
        assert_eq!(config.sql_max_rows, 5000);
    }
}