Skip to main content

pmcp_code_mode/
config.rs

1//! Code Mode configuration.
2
3use crate::types::{RiskLevel, ValidationError};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::collections::HashSet;
7
/// Resolve a `server_id` from environment variables.
///
/// Checks, in order:
/// 1. `PMCP_SERVER_ID`
/// 2. `AWS_LAMBDA_FUNCTION_NAME` (Lambda runtime)
///
/// Each variable is considered on its own: one that is unset, empty, or not
/// valid Unicode is skipped. An empty `PMCP_SERVER_ID` therefore falls through
/// to `AWS_LAMBDA_FUNCTION_NAME` instead of shadowing it. Returns `None` if
/// neither variable yields a non-empty value.
///
/// This is the same resolution chain used by
/// [`CodeModeConfig::resolve_server_id`] — exposed as a free function so tests
/// and non-pipeline code can share it.
pub fn resolve_server_id_from_env() -> Option<String> {
    // Highest-priority source first.
    const SOURCES: [&str; 2] = ["PMCP_SERVER_ID", "AWS_LAMBDA_FUNCTION_NAME"];
    SOURCES
        .iter()
        // Filter empties per variable, so an empty value does not stop the search.
        .find_map(|key| std::env::var(key).ok().filter(|value| !value.is_empty()))
}
29
/// A single declared operation in Code Mode configuration.
/// Maps a raw API path to a canonical plain-name ID for Cedar policies.
///
/// Deserialized from each `[[code_mode.operations]]` TOML table. `id` and
/// `category` carry no serde default, so they must be present; `description`
/// defaults to an empty string and `path` to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationEntry {
    /// Canonical operation ID (plain name, no method prefix).
    /// This is what appears in Cedar policy calledOperations.
    pub id: String,

    /// Action category for AVP action routing.
    /// Values: "read", "write", "delete", "admin"
    /// An empty string is treated as "no category declared" by
    /// [`OperationRegistry::from_entries`].
    pub category: String,

    /// Human-readable description for admin UI and LLM context.
    #[serde(default)]
    pub description: String,

    /// Raw API path this ID maps to (e.g., "/getCostAnomalies").
    /// Used to match against api_call.path from JavaScript analysis.
    /// Entries without a path are not added to an [`OperationRegistry`].
    #[serde(default)]
    pub path: Option<String>,
}
51
52/// Registry built from [[code_mode.operations]] config entries.
53/// Maps raw paths to canonical operation IDs and categories.
54#[derive(Debug, Clone, Default)]
55pub struct OperationRegistry {
56    path_to_id: HashMap<String, String>,
57    path_to_category: HashMap<String, String>,
58}
59
60impl OperationRegistry {
61    pub fn from_entries(entries: &[OperationEntry]) -> Self {
62        let mut path_to_id = HashMap::with_capacity(entries.len());
63        let mut path_to_category = HashMap::with_capacity(entries.len());
64        for entry in entries {
65            if let Some(ref path) = entry.path {
66                path_to_id.insert(path.clone(), entry.id.clone());
67                if !entry.category.is_empty() {
68                    path_to_category.insert(path.clone(), entry.category.clone());
69                }
70            }
71        }
72        Self {
73            path_to_id,
74            path_to_category,
75        }
76    }
77
78    pub fn lookup(&self, path: &str) -> Option<&str> {
79        self.path_to_id.get(path).map(|s| s.as_str())
80    }
81
82    /// Look up the declared category for a path (e.g., "read", "write", "delete", "admin").
83    /// Returns `None` if the path has no registry entry or no category declared.
84    pub fn lookup_category(&self, path: &str) -> Option<&str> {
85        self.path_to_category.get(path).map(|s| s.as_str())
86    }
87
88    pub fn is_empty(&self) -> bool {
89        self.path_to_id.is_empty()
90    }
91}
92
/// Configuration for Code Mode.
///
/// Usually deserialized from the `[code_mode]` table of a server's TOML config
/// (see [`CodeModeConfig::from_toml`]). Every field has a serde default, so an
/// empty table is valid, and those defaults agree with
/// [`CodeModeConfig::default`].
///
/// Settings are grouped by backend:
/// - GraphQL fields are unprefixed.
/// - OpenAPI fields use the `openapi_` prefix.
/// - SQL fields use the `sql_` prefix and also accept unprefixed aliases.
///
/// The remaining "common" settings carry no backend prefix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeModeConfig {
    /// Whether Code Mode is enabled for this server (default: false).
    #[serde(default)]
    pub enabled: bool,

    // ========================================================================
    // GraphQL-specific settings
    // ========================================================================
    /// Whether to allow mutations (MVP: false).
    /// Mapped to both `allow_write` and `allow_delete` by
    /// [`CodeModeConfig::to_server_config_entity`].
    #[serde(default)]
    pub allow_mutations: bool,

    /// Allowed mutation names (whitelist). If empty and allow_mutations=true, all are allowed.
    #[serde(default)]
    pub allowed_mutations: HashSet<String>,

    /// Blocked mutation names (blacklist). Always blocked even if allow_mutations=true.
    #[serde(default)]
    pub blocked_mutations: HashSet<String>,

    /// Whether to allow introspection queries.
    /// Mapped to `allow_admin` by [`CodeModeConfig::to_server_config_entity`].
    #[serde(default)]
    pub allow_introspection: bool,

    /// Fields that should never be returned (Type.field format) - GraphQL
    #[serde(default)]
    pub blocked_fields: HashSet<String>,

    /// Allowed query names (whitelist). If empty and mode is allowlist, none are allowed.
    /// NOTE(review): this struct has no explicit mode field; the allowlist vs.
    /// blocklist mode is presumably decided by the consumer — confirm.
    #[serde(default)]
    pub allowed_queries: HashSet<String>,

    /// Blocked query names (blocklist). Always blocked even if reads enabled.
    #[serde(default)]
    pub blocked_queries: HashSet<String>,

    // ========================================================================
    // OpenAPI-specific settings
    // ========================================================================
    /// Whether read operations (GET) are enabled (default: true).
    /// Also forwarded as `auto_approve_read_only` to the OpenAPI policy entity.
    #[serde(default = "default_true")]
    pub openapi_reads_enabled: bool,

    /// Whether write operations (POST, PUT, PATCH) are allowed globally.
    /// When false, the OpenAPI policy entity's write mode is `deny_all`.
    #[serde(default)]
    pub openapi_allow_writes: bool,

    /// Allowed write operations (operationId or "METHOD /path").
    /// When non-empty (and writes are enabled), the write mode is `allowlist`.
    #[serde(default)]
    pub openapi_allowed_writes: HashSet<String>,

    /// Blocked write operations.
    /// Selects the `blocklist` write mode when writes are enabled and no
    /// allowlist is given.
    #[serde(default)]
    pub openapi_blocked_writes: HashSet<String>,

    /// Whether delete operations (DELETE) are allowed globally
    #[serde(default)]
    pub openapi_allow_deletes: bool,

    /// Allowed delete operations (operationId or "METHOD /path").
    /// Merged with `openapi_allowed_writes` into the entity's allowed operations.
    #[serde(default)]
    pub openapi_allowed_deletes: HashSet<String>,

    /// Blocked paths (glob patterns like "/admin/*").
    /// Forwarded as both blocked and sensitive path patterns to the entity.
    #[serde(default)]
    pub openapi_blocked_paths: HashSet<String>,

    /// Fields that are stripped from API responses entirely (no access)
    #[serde(default)]
    pub openapi_internal_blocked_fields: HashSet<String>,

    /// Fields that can be used internally but not in script output
    #[serde(default)]
    pub openapi_output_blocked_fields: HashSet<String>,

    /// Whether scripts must declare their return type with @returns
    #[serde(default)]
    pub openapi_require_output_declaration: bool,

    // ========================================================================
    // SQL-specific settings
    //
    // SQL fields accept both their prefixed name (`sql_allow_writes`) and the
    // unprefixed natural form (`allow_writes`). Downstream SQL servers can use
    // the unprefixed names in their `[code_mode]` block without a manual
    // conversion layer:
    //
    //     [code_mode]
    //     reads_enabled = true    # same as sql_reads_enabled
    //     allow_writes = false    # same as sql_allow_writes
    //     blocked_tables = ["secrets"]
    //     max_rows = 5000
    // ========================================================================
    /// Whether SELECT statements are enabled (default: true).
    #[serde(default = "default_true", alias = "reads_enabled")]
    pub sql_reads_enabled: bool,

    /// Whether INSERT/UPDATE/MERGE statements are allowed globally.
    #[serde(default, alias = "allow_writes")]
    pub sql_allow_writes: bool,

    /// Whether DELETE/TRUNCATE statements are allowed globally.
    #[serde(default, alias = "allow_deletes")]
    pub sql_allow_deletes: bool,

    /// Whether DDL (CREATE/ALTER/DROP/GRANT/REVOKE) is allowed globally.
    /// Default is `false` — DDL is almost never appropriate for LLM-generated code.
    /// Exposed as `allow_admin` on the SQL policy entity.
    #[serde(default, alias = "allow_ddl")]
    pub sql_allow_ddl: bool,

    /// Allowed statement types ("SELECT"/"INSERT"/"UPDATE"/"DELETE"/"DDL").
    /// If non-empty, only statement types in this set are allowed.
    #[serde(default, alias = "allowed_statements")]
    pub sql_allowed_statements: HashSet<String>,

    /// Blocked statement types. Always blocked even if globally allowed.
    #[serde(default, alias = "blocked_statements")]
    pub sql_blocked_statements: HashSet<String>,

    /// Tables that are always forbidden (blocklist mode).
    #[serde(default, alias = "blocked_tables")]
    pub sql_blocked_tables: HashSet<String>,

    /// If non-empty, only these tables can be accessed (allowlist mode).
    #[serde(default, alias = "allowed_tables")]
    pub sql_allowed_tables: HashSet<String>,

    /// Columns that may not be referenced in any statement (e.g., `password`, `ssn`).
    #[serde(default, alias = "blocked_columns")]
    pub sql_blocked_columns: HashSet<String>,

    /// Maximum row-count estimate allowed (based on LIMIT or default estimate).
    /// Default: 10_000.
    #[serde(default = "default_sql_max_rows", alias = "max_rows")]
    pub sql_max_rows: u64,

    /// Maximum number of JOINs in a single statement (default: 5).
    #[serde(default = "default_sql_max_joins", alias = "max_joins")]
    pub sql_max_joins: u32,

    /// Whether to require a WHERE clause for UPDATE/DELETE statements (default: true).
    #[serde(default = "default_true", alias = "require_where_on_writes")]
    pub sql_require_where_on_writes: bool,

    // ========================================================================
    // Common settings
    // ========================================================================
    /// Action tags to override inferred actions for specific operations.
    #[serde(default)]
    pub action_tags: HashMap<String, String>,

    /// Maximum query depth (default: 10).
    #[serde(default = "default_max_depth")]
    pub max_depth: u32,

    /// Maximum field count per query (default: 100).
    #[serde(default = "default_max_field_count")]
    pub max_field_count: u32,

    /// Maximum estimated query cost (default: 1000).
    #[serde(default = "default_max_cost")]
    pub max_cost: u32,

    /// Allowed sensitive data categories
    #[serde(default)]
    pub allowed_sensitive_categories: HashSet<String>,

    /// Token time-to-live in seconds (default: 300, i.e. 5 minutes).
    #[serde(default = "default_token_ttl")]
    pub token_ttl_seconds: i64,

    /// Risk levels that can be auto-approved without human confirmation
    /// (default: `[RiskLevel::Low]`).
    #[serde(default = "default_auto_approve_levels")]
    pub auto_approve_levels: Vec<RiskLevel>,

    /// Maximum query length in characters (default: 10000).
    /// Also used as the OpenAPI entity's `max_script_length`.
    #[serde(default = "default_max_query_length")]
    pub max_query_length: usize,

    /// Maximum result rows to return (default: 10000).
    #[serde(default = "default_max_result_rows")]
    pub max_result_rows: usize,

    /// Query execution timeout in seconds (default: 30).
    #[serde(default = "default_query_timeout")]
    pub query_timeout_seconds: u32,

    /// Server ID for token generation.
    /// Also copied into the `server_id` of every policy entity built from this
    /// config. When absent, see [`CodeModeConfig::resolve_server_id`].
    #[serde(default)]
    pub server_id: Option<String>,

    // ========================================================================
    // SDK-backed settings
    // ========================================================================
    /// Allowed SDK operation names for SDK-backed Code Mode.
    /// When non-empty, Code Mode uses SDK dispatch instead of HTTP.
    /// Operations are validated at compile time — unlisted names are rejected.
    #[serde(default)]
    pub sdk_operations: HashSet<String>,

    /// Declared operations for plain-name ID mapping in Cedar entities.
    /// Parsed from [[code_mode.operations]] TOML sections.
    /// When non-empty, ScriptEntity calledOperations uses IDs from the registry
    /// built from these entries. Unregistered paths fall back to METHOD:/path.
    #[serde(default)]
    pub operations: Vec<OperationEntry>,
}
301
302impl Default for CodeModeConfig {
303    fn default() -> Self {
304        Self {
305            enabled: false,
306            // GraphQL
307            allow_mutations: false,
308            allowed_mutations: HashSet::new(),
309            blocked_mutations: HashSet::new(),
310            allow_introspection: false,
311            blocked_fields: HashSet::new(),
312            allowed_queries: HashSet::new(),
313            blocked_queries: HashSet::new(),
314            // OpenAPI
315            openapi_reads_enabled: true,
316            openapi_allow_writes: false,
317            openapi_allowed_writes: HashSet::new(),
318            openapi_blocked_writes: HashSet::new(),
319            openapi_allow_deletes: false,
320            openapi_allowed_deletes: HashSet::new(),
321            openapi_blocked_paths: HashSet::new(),
322            openapi_internal_blocked_fields: HashSet::new(),
323            openapi_output_blocked_fields: HashSet::new(),
324            openapi_require_output_declaration: false,
325            // SQL
326            sql_reads_enabled: true,
327            sql_allow_writes: false,
328            sql_allow_deletes: false,
329            sql_allow_ddl: false,
330            sql_allowed_statements: HashSet::new(),
331            sql_blocked_statements: HashSet::new(),
332            sql_blocked_tables: HashSet::new(),
333            sql_allowed_tables: HashSet::new(),
334            sql_blocked_columns: HashSet::new(),
335            sql_max_rows: default_sql_max_rows(),
336            sql_max_joins: default_sql_max_joins(),
337            sql_require_where_on_writes: true,
338            // Common
339            action_tags: HashMap::new(),
340            max_depth: default_max_depth(),
341            max_field_count: default_max_field_count(),
342            max_cost: default_max_cost(),
343            allowed_sensitive_categories: HashSet::new(),
344            token_ttl_seconds: default_token_ttl(),
345            auto_approve_levels: default_auto_approve_levels(),
346            max_query_length: default_max_query_length(),
347            max_result_rows: default_max_result_rows(),
348            query_timeout_seconds: default_query_timeout(),
349            server_id: None,
350            // SDK
351            sdk_operations: HashSet::new(),
352            operations: Vec::new(),
353        }
354    }
355}
356
/// Wrapper for deserializing the `[code_mode]` section from a full TOML config file.
/// The file may contain other sections (`[server]`, `[[tools]]`, etc.) which are ignored.
///
/// Unknown top-level tables are tolerated because no `deny_unknown_fields`
/// attribute is set.
#[derive(Deserialize)]
struct TomlWrapper {
    // When the `[code_mode]` table is absent, serde substitutes
    // `CodeModeConfig::default()` for this field.
    #[serde(default)]
    code_mode: CodeModeConfig,
}
364
365impl CodeModeConfig {
366    /// Parse `CodeModeConfig` from a full TOML config string.
367    ///
368    /// Extracts the `[code_mode]` section (including `[[code_mode.operations]]`)
369    /// and ignores all other sections. This is the recommended way for external
370    /// servers to build their config from `config.toml`:
371    ///
372    /// ```rust,ignore
373    /// const CONFIG_TOML: &str = include_str!("../../config.toml");
374    ///
375    /// let config = CodeModeConfig::from_toml(CONFIG_TOML)
376    ///     .expect("Invalid code_mode section in config.toml");
377    /// ```
378    ///
379    /// If the TOML has no `[code_mode]` section, returns `CodeModeConfig::default()`.
380    pub fn from_toml(toml_str: &str) -> Result<Self, toml::de::Error> {
381        let wrapper: TomlWrapper = toml::from_str(toml_str)?;
382        Ok(wrapper.code_mode)
383    }
384
385    /// Create a new config with Code Mode enabled.
386    pub fn enabled() -> Self {
387        Self {
388            enabled: true,
389            ..Default::default()
390        }
391    }
392
393    /// Returns true if this config enables SDK-backed Code Mode.
394    pub fn is_sdk_mode(&self) -> bool {
395        !self.sdk_operations.is_empty()
396    }
397
398    /// Check if a risk level should be auto-approved.
399    pub fn should_auto_approve(&self, risk_level: RiskLevel) -> bool {
400        self.auto_approve_levels.contains(&risk_level)
401    }
402
403    /// Get the server ID, falling back to a default.
404    ///
405    /// **Note:** The `"unknown"` fallback produces silent AVP default-deny failures
406    /// (no Cedar policy matches a server_id of `"unknown"`). Prefer
407    /// [`resolve_server_id`](Self::resolve_server_id) to auto-fill from environment,
408    /// or [`require_server_id`](Self::require_server_id) to fail fast.
409    pub fn server_id(&self) -> &str {
410        self.server_id.as_deref().unwrap_or("unknown")
411    }
412
413    /// Auto-resolve `server_id` from environment if not already set.
414    ///
415    /// Resolution order:
416    /// 1. `self.server_id` (if already set, e.g., from TOML) — no change
417    /// 2. `PMCP_SERVER_ID` env var
418    /// 3. `AWS_LAMBDA_FUNCTION_NAME` env var (Lambda runtime)
419    /// 4. Left as `None` — caller is responsible for handling
420    ///
421    /// [`ValidationPipeline`](crate::ValidationPipeline) constructors call this
422    /// automatically, so wrappers rarely need to invoke it directly.
423    pub fn resolve_server_id(&mut self) {
424        if self.server_id.is_some() {
425            return;
426        }
427        self.server_id = resolve_server_id_from_env();
428    }
429
430    /// Return the `server_id`, or an error if not resolved.
431    ///
432    /// Use this in production code paths that require AVP authorization —
433    /// it fails fast with a clear message instead of letting `"unknown"`
434    /// reach AVP and produce a silent default-deny.
435    pub fn require_server_id(&self) -> Result<&str, ValidationError> {
436        self.server_id.as_deref().ok_or_else(|| {
437            ValidationError::ConfigError(
438                "server_id is not set. Set it in config.toml, PMCP_SERVER_ID env var, \
439                 or AWS_LAMBDA_FUNCTION_NAME (Lambda). Without it, AVP authorization \
440                 will default-deny silently."
441                    .into(),
442            )
443        })
444    }
445
446    /// Convert to ServerConfigEntity for policy evaluation.
447    pub fn to_server_config_entity(&self) -> crate::policy::ServerConfigEntity {
448        crate::policy::ServerConfigEntity {
449            server_id: self.server_id().to_string(),
450            server_type: "graphql".to_string(),
451            allow_write: self.allow_mutations,
452            allow_delete: self.allow_mutations,
453            allow_admin: self.allow_introspection,
454            allowed_operations: self.allowed_mutations.clone(),
455            blocked_operations: self.blocked_mutations.clone(),
456            max_depth: self.max_depth,
457            max_field_count: self.max_field_count,
458            max_cost: self.max_cost,
459            max_api_calls: 50,
460            blocked_fields: self.blocked_fields.clone(),
461            allowed_sensitive_categories: self.allowed_sensitive_categories.clone(),
462        }
463    }
464
465    /// Convert to OpenAPIServerEntity for policy evaluation (OpenAPI Code Mode).
466    #[cfg(feature = "openapi-code-mode")]
467    pub fn to_openapi_server_entity(&self) -> crate::policy::OpenAPIServerEntity {
468        let mut allowed_operations = self.openapi_allowed_writes.clone();
469        allowed_operations.extend(self.openapi_allowed_deletes.clone());
470
471        let write_mode = if !self.openapi_allow_writes {
472            "deny_all"
473        } else if !self.openapi_allowed_writes.is_empty() {
474            "allowlist"
475        } else if !self.openapi_blocked_writes.is_empty() {
476            "blocklist"
477        } else {
478            "allow_all"
479        };
480
481        crate::policy::OpenAPIServerEntity {
482            server_id: self.server_id().to_string(),
483            server_type: "openapi".to_string(),
484            allow_write: self.openapi_allow_writes,
485            allow_delete: self.openapi_allow_deletes,
486            allow_admin: false,
487            write_mode: write_mode.to_string(),
488            max_depth: self.max_depth,
489            max_cost: self.max_cost,
490            max_api_calls: 50,
491            max_loop_iterations: 100,
492            max_script_length: self.max_query_length as u32,
493            max_nesting_depth: self.max_depth,
494            execution_timeout_seconds: self.query_timeout_seconds,
495            allowed_operations,
496            blocked_operations: self.openapi_blocked_writes.clone(),
497            allowed_methods: HashSet::new(),
498            blocked_methods: HashSet::new(),
499            allowed_path_patterns: HashSet::new(),
500            blocked_path_patterns: self.openapi_blocked_paths.clone(),
501            sensitive_path_patterns: self.openapi_blocked_paths.clone(),
502            auto_approve_read_only: self.openapi_reads_enabled,
503            max_api_calls_for_auto_approve: 10,
504            internal_blocked_fields: self.openapi_internal_blocked_fields.clone(),
505            output_blocked_fields: self.openapi_output_blocked_fields.clone(),
506            require_output_declaration: self.openapi_require_output_declaration,
507        }
508    }
509
510    /// Convert to `SqlServerEntity` for policy evaluation (SQL Code Mode).
511    #[cfg(feature = "sql-code-mode")]
512    pub fn to_sql_server_entity(&self) -> crate::policy::SqlServerEntity {
513        crate::policy::SqlServerEntity {
514            server_id: self.server_id().to_string(),
515            server_type: "sql".to_string(),
516            allow_write: self.sql_allow_writes,
517            allow_delete: self.sql_allow_deletes,
518            allow_admin: self.sql_allow_ddl,
519            max_rows: self.sql_max_rows,
520            max_joins: self.sql_max_joins,
521            allowed_operations: self.sql_allowed_statements.clone(),
522            blocked_operations: self.sql_blocked_statements.clone(),
523            blocked_tables: self.sql_blocked_tables.clone(),
524            blocked_columns: self.sql_blocked_columns.clone(),
525            allowed_tables: self.sql_allowed_tables.clone(),
526        }
527    }
528}
529
/// Serde default helper for boolean settings that are on unless disabled.
fn default_true() -> bool {
    !false
}
533
/// Default token lifetime in seconds: five minutes.
fn default_token_ttl() -> i64 {
    5 * 60
}
537
538fn default_auto_approve_levels() -> Vec<RiskLevel> {
539    vec![RiskLevel::Low]
540}
541
/// Default maximum query/script length (10 000 characters).
fn default_max_query_length() -> usize {
    10_000
}
545
/// Default cap on returned result rows (10 000).
fn default_max_result_rows() -> usize {
    10_000
}
549
/// Default query execution timeout: half a minute, in seconds.
fn default_query_timeout() -> u32 {
    60 / 2
}
553
/// Default maximum query depth (also used as the OpenAPI nesting limit).
fn default_max_depth() -> u32 {
    10_u32
}
557
/// Default maximum number of fields per query.
fn default_max_field_count() -> u32 {
    10 * 10
}
561
/// Default ceiling on estimated query cost.
fn default_max_cost() -> u32 {
    1_000
}
565
/// Default SQL row-count estimate limit (ten thousand rows).
fn default_sql_max_rows() -> u64 {
    10 * 1_000
}
569
/// Default maximum number of JOINs allowed in one SQL statement.
fn default_sql_max_joins() -> u32 {
    5_u32
}
573
574#[cfg(test)]
575mod tests {
576    use super::*;
577
578    #[test]
579    fn test_default_config() {
580        let config = CodeModeConfig::default();
581        assert!(!config.enabled);
582        assert!(!config.allow_mutations);
583        assert_eq!(config.token_ttl_seconds, 300);
584        assert_eq!(config.auto_approve_levels, vec![RiskLevel::Low]);
585    }
586
587    #[test]
588    fn test_enabled_config() {
589        let config = CodeModeConfig::enabled();
590        assert!(config.enabled);
591    }
592
593    #[test]
594    fn test_auto_approve() {
595        let config = CodeModeConfig::default();
596        assert!(config.should_auto_approve(RiskLevel::Low));
597        assert!(!config.should_auto_approve(RiskLevel::Medium));
598        assert!(!config.should_auto_approve(RiskLevel::High));
599        assert!(!config.should_auto_approve(RiskLevel::Critical));
600    }
601
602    #[test]
603    fn test_operation_registry_from_entries() {
604        let entries = vec![
605            OperationEntry {
606                id: "getCostAnomalies".to_string(),
607                category: "read".to_string(),
608                description: "Get cost anomalies".to_string(),
609                path: Some("/getCostAnomalies".to_string()),
610            },
611            OperationEntry {
612                id: "listInstances".to_string(),
613                category: "read".to_string(),
614                description: "List EC2 instances".to_string(),
615                path: Some("/listInstances".to_string()),
616            },
617        ];
618        let registry = OperationRegistry::from_entries(&entries);
619        assert_eq!(
620            registry.lookup("/getCostAnomalies"),
621            Some("getCostAnomalies")
622        );
623        assert_eq!(registry.lookup("/listInstances"), Some("listInstances"));
624    }
625
626    #[test]
627    fn test_operation_registry_lookup_unregistered() {
628        let entries = vec![OperationEntry {
629            id: "getCostAnomalies".to_string(),
630            category: "read".to_string(),
631            description: String::new(),
632            path: Some("/getCostAnomalies".to_string()),
633        }];
634        let registry = OperationRegistry::from_entries(&entries);
635        assert_eq!(registry.lookup("/unknownPath"), None);
636        assert_eq!(registry.lookup(""), None);
637    }
638
639    #[test]
640    fn test_operation_registry_lookup_category() {
641        let entries = vec![
642            OperationEntry {
643                id: "getCostAnomalies".to_string(),
644                category: "read".to_string(),
645                description: String::new(),
646                path: Some("/getCostAnomalies".to_string()),
647            },
648            OperationEntry {
649                id: "deleteReservation".to_string(),
650                category: "delete".to_string(),
651                description: String::new(),
652                path: Some("/deleteReservation".to_string()),
653            },
654            OperationEntry {
655                id: "updateBudget".to_string(),
656                category: "write".to_string(),
657                description: String::new(),
658                path: Some("/updateBudget".to_string()),
659            },
660        ];
661        let registry = OperationRegistry::from_entries(&entries);
662        assert_eq!(registry.lookup_category("/getCostAnomalies"), Some("read"));
663        assert_eq!(
664            registry.lookup_category("/deleteReservation"),
665            Some("delete")
666        );
667        assert_eq!(registry.lookup_category("/updateBudget"), Some("write"));
668        assert_eq!(registry.lookup_category("/unknownPath"), None);
669    }
670
671    #[test]
672    fn test_operation_registry_empty_category_excluded() {
673        let entries = vec![OperationEntry {
674            id: "legacyOp".to_string(),
675            category: String::new(), // empty = not declared
676            description: String::new(),
677            path: Some("/legacyOp".to_string()),
678        }];
679        let registry = OperationRegistry::from_entries(&entries);
680        // ID lookup still works
681        assert_eq!(registry.lookup("/legacyOp"), Some("legacyOp"));
682        // Category lookup returns None for empty category
683        assert_eq!(registry.lookup_category("/legacyOp"), None);
684    }
685
686    #[test]
687    fn test_operation_registry_is_empty() {
688        let empty_registry = OperationRegistry::from_entries(&[]);
689        assert!(empty_registry.is_empty());
690
691        let entries = vec![OperationEntry {
692            id: "op1".to_string(),
693            category: "read".to_string(),
694            description: String::new(),
695            path: Some("/op1".to_string()),
696        }];
697        let registry = OperationRegistry::from_entries(&entries);
698        assert!(!registry.is_empty());
699    }
700
701    #[test]
702    fn test_operation_entry_deserialization() {
703        let toml_str = r#"
704id = "getCostAnomalies"
705category = "read"
706description = "Get cost anomalies"
707path = "/getCostAnomalies"
708"#;
709        let entry: OperationEntry =
710            toml::from_str(toml_str).expect("Failed to deserialize OperationEntry");
711        assert_eq!(entry.id, "getCostAnomalies");
712        assert_eq!(entry.category, "read");
713        assert_eq!(entry.description, "Get cost anomalies");
714        assert_eq!(entry.path, Some("/getCostAnomalies".to_string()));
715    }
716
717    #[test]
718    fn test_code_mode_config_with_operations() {
719        let toml_str = r#"
720enabled = true
721
722[[operations]]
723id = "getCostAnomalies"
724category = "read"
725description = "Get cost anomalies"
726path = "/getCostAnomalies"
727
728[[operations]]
729id = "listInstances"
730category = "read"
731path = "/listInstances"
732"#;
733        let config: CodeModeConfig = toml::from_str(toml_str).expect("Failed to deserialize");
734        assert!(config.enabled);
735        assert_eq!(config.operations.len(), 2);
736        assert_eq!(config.operations[0].id, "getCostAnomalies");
737        assert_eq!(config.operations[1].id, "listInstances");
738    }
739
740    #[test]
741    fn test_code_mode_config_without_operations_defaults_to_empty() {
742        let toml_str = r#"
743enabled = true
744"#;
745        let config: CodeModeConfig = toml::from_str(toml_str).expect("Failed to deserialize");
746        assert!(config.enabled);
747        assert!(config.operations.is_empty());
748    }
749
750    #[test]
751    fn test_from_toml_extracts_code_mode_section() {
752        let toml_str = r#"
753[server]
754name = "cost-coach"
755type = "openapi-api"
756
757[code_mode]
758enabled = true
759token_ttl_seconds = 600
760server_id = "cost-coach"
761
762[[code_mode.operations]]
763id = "getCostAndUsage"
764category = "read"
765description = "Historical cost and usage data"
766path = "/getCostAndUsage"
767
768[[code_mode.operations]]
769id = "getCostAnomalies"
770category = "read"
771description = "Cost anomalies detected by AWS"
772path = "/getCostAnomalies"
773
774[[tools]]
775name = "some_tool"
776"#;
777        let config = CodeModeConfig::from_toml(toml_str).expect("Failed to parse");
778        assert!(config.enabled);
779        assert_eq!(config.token_ttl_seconds, 600);
780        assert_eq!(config.server_id, Some("cost-coach".to_string()));
781        assert_eq!(config.operations.len(), 2);
782        assert_eq!(config.operations[0].id, "getCostAndUsage");
783        assert_eq!(config.operations[1].id, "getCostAnomalies");
784        assert_eq!(
785            config.operations[0].path,
786            Some("/getCostAndUsage".to_string())
787        );
788    }
789
790    #[test]
791    fn test_from_toml_missing_code_mode_returns_default() {
792        let toml_str = r#"
793[server]
794name = "some-server"
795"#;
796        let config = CodeModeConfig::from_toml(toml_str).expect("Failed to parse");
797        assert!(!config.enabled);
798        assert!(config.operations.is_empty());
799        assert_eq!(config.token_ttl_seconds, 300); // default
800    }
801
802    // =========================================================================
803    // server_id resolution tests
804    //
805    // These tests mutate process-wide env vars. Cargo parallelizes tests across
806    // threads in the same process, so a shared Mutex serializes them — without
807    // this, set_var/remove_var in one test would race with another.
808    // =========================================================================
809
810    use std::sync::Mutex;
811    static ENV_LOCK: Mutex<()> = Mutex::new(());
812
813    struct EnvGuard {
814        _lock: std::sync::MutexGuard<'static, ()>,
815    }
816
817    impl EnvGuard {
818        fn acquire() -> Self {
819            let lock = ENV_LOCK
820                .lock()
821                .unwrap_or_else(|poisoned| poisoned.into_inner());
822            std::env::remove_var("PMCP_SERVER_ID");
823            std::env::remove_var("AWS_LAMBDA_FUNCTION_NAME");
824            Self { _lock: lock }
825        }
826    }
827
828    impl Drop for EnvGuard {
829        fn drop(&mut self) {
830            std::env::remove_var("PMCP_SERVER_ID");
831            std::env::remove_var("AWS_LAMBDA_FUNCTION_NAME");
832        }
833    }
834
835    #[test]
836    fn resolve_server_id_from_explicit_config_takes_precedence() {
837        let _g = EnvGuard::acquire();
838        std::env::set_var("PMCP_SERVER_ID", "from-env");
839
840        let mut config = CodeModeConfig {
841            server_id: Some("from-config".to_string()),
842            ..Default::default()
843        };
844        config.resolve_server_id();
845
846        assert_eq!(config.server_id.as_deref(), Some("from-config"));
847    }
848
849    #[test]
850    fn resolve_server_id_from_pmcp_env() {
851        let _g = EnvGuard::acquire();
852        std::env::set_var("PMCP_SERVER_ID", "my-server");
853
854        let mut config = CodeModeConfig::default();
855        config.resolve_server_id();
856
857        assert_eq!(config.server_id.as_deref(), Some("my-server"));
858    }
859
860    #[test]
861    fn resolve_server_id_from_lambda_env() {
862        let _g = EnvGuard::acquire();
863        std::env::set_var("AWS_LAMBDA_FUNCTION_NAME", "my-lambda-fn");
864
865        let mut config = CodeModeConfig::default();
866        config.resolve_server_id();
867
868        assert_eq!(config.server_id.as_deref(), Some("my-lambda-fn"));
869    }
870
871    #[test]
872    fn resolve_server_id_pmcp_wins_over_lambda() {
873        let _g = EnvGuard::acquire();
874        std::env::set_var("PMCP_SERVER_ID", "explicit");
875        std::env::set_var("AWS_LAMBDA_FUNCTION_NAME", "lambda-fn");
876
877        let mut config = CodeModeConfig::default();
878        config.resolve_server_id();
879
880        assert_eq!(config.server_id.as_deref(), Some("explicit"));
881    }
882
883    #[test]
884    fn resolve_server_id_leaves_none_when_unset() {
885        let _g = EnvGuard::acquire();
886        let mut config = CodeModeConfig::default();
887        config.resolve_server_id();
888        assert!(config.server_id.is_none());
889    }
890
891    #[test]
892    fn require_server_id_errors_when_unset() {
893        let config = CodeModeConfig::default();
894        let result = config.require_server_id();
895        assert!(matches!(result, Err(ValidationError::ConfigError(_))));
896    }
897
898    #[test]
899    fn require_server_id_returns_value_when_set() {
900        let config = CodeModeConfig {
901            server_id: Some("my-server".to_string()),
902            ..Default::default()
903        };
904        assert_eq!(config.require_server_id().unwrap(), "my-server");
905    }
906
907    #[test]
908    fn resolve_server_id_from_env_free_fn_treats_empty_as_unset() {
909        let _g = EnvGuard::acquire();
910        std::env::set_var("PMCP_SERVER_ID", "");
911        assert_eq!(resolve_server_id_from_env(), None);
912    }
913
914    // =========================================================================
915    // SQL TOML DX tests (serde aliases)
916    // =========================================================================
917
918    #[test]
919    fn sql_config_accepts_unprefixed_toml_names() {
920        let toml_str = r#"
921enabled = true
922allow_writes = true
923allow_deletes = true
924allow_ddl = true
925allowed_tables = ["users", "orders"]
926blocked_tables = ["secrets"]
927blocked_columns = ["password", "ssn"]
928max_rows = 5000
929max_joins = 3
930require_where_on_writes = false
931"#;
932        let config: CodeModeConfig =
933            toml::from_str(toml_str).expect("Failed to deserialize with unprefixed aliases");
934
935        assert!(config.enabled);
936        assert!(config.sql_allow_writes);
937        assert!(config.sql_allow_deletes);
938        assert!(config.sql_allow_ddl);
939        assert!(config.sql_allowed_tables.contains("users"));
940        assert!(config.sql_allowed_tables.contains("orders"));
941        assert!(config.sql_blocked_tables.contains("secrets"));
942        assert!(config.sql_blocked_columns.contains("password"));
943        assert_eq!(config.sql_max_rows, 5000);
944        assert_eq!(config.sql_max_joins, 3);
945        assert!(!config.sql_require_where_on_writes);
946    }
947
948    #[test]
949    fn sql_config_accepts_prefixed_toml_names() {
950        let toml_str = r#"
951enabled = true
952sql_allow_writes = true
953sql_blocked_tables = ["secrets"]
954sql_max_rows = 5000
955"#;
956        let config: CodeModeConfig =
957            toml::from_str(toml_str).expect("Failed to deserialize with prefixed names");
958
959        assert!(config.sql_allow_writes);
960        assert!(config.sql_blocked_tables.contains("secrets"));
961        assert_eq!(config.sql_max_rows, 5000);
962    }
963}