Skip to main content

pmcp_code_mode/
config.rs

1//! Code Mode configuration.
2
3use crate::types::{RiskLevel, ValidationError};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::collections::HashSet;
7
/// Resolve a `server_id` from environment variables.
///
/// Checks, in order:
/// 1. `PMCP_SERVER_ID`
/// 2. `AWS_LAMBDA_FUNCTION_NAME` (Lambda runtime)
///
/// Returns the first variable that is set to a non-empty, valid-Unicode value,
/// or `None` if neither qualifies. Empty strings are treated as unset, so an
/// empty `PMCP_SERVER_ID` falls through to `AWS_LAMBDA_FUNCTION_NAME` instead
/// of ending resolution early.
///
/// This is the same resolution chain used by
/// [`CodeModeConfig::resolve_server_id`] — exposed as a free function so tests
/// and non-pipeline code can share it.
pub fn resolve_server_id_from_env() -> Option<String> {
    // Highest-priority source first.
    const SOURCES: [&str; 2] = ["PMCP_SERVER_ID", "AWS_LAMBDA_FUNCTION_NAME"];
    SOURCES
        .iter()
        .find_map(|&key| std::env::var(key).ok().filter(|value| !value.is_empty()))
}
29
/// A single declared operation in Code Mode configuration.
/// Maps a raw API path to a canonical plain-name ID for Cedar policies.
///
/// Deserialized from each `[[code_mode.operations]]` table. `id` and
/// `category` are required keys (they carry no serde default); `description`
/// defaults to an empty string and `path` to `None` when omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationEntry {
    /// Canonical operation ID (plain name, no method prefix).
    /// This is what appears in Cedar policy calledOperations.
    pub id: String,

    /// Action category for AVP action routing.
    /// Values: "read", "write", "delete", "admin"
    /// `OperationRegistry` treats an empty string as "no category declared".
    pub category: String,

    /// Human-readable description for admin UI and LLM context.
    #[serde(default)]
    pub description: String,

    /// Raw API path this ID maps to (e.g., "/getCostAnomalies").
    /// Used to match against api_call.path from JavaScript analysis.
    /// Entries without a path are not indexed by `OperationRegistry`.
    #[serde(default)]
    pub path: Option<String>,
}
51
52/// Registry built from [[code_mode.operations]] config entries.
53/// Maps raw paths to canonical operation IDs and categories.
54#[derive(Debug, Clone, Default)]
55pub struct OperationRegistry {
56    path_to_id: HashMap<String, String>,
57    path_to_category: HashMap<String, String>,
58}
59
60impl OperationRegistry {
61    pub fn from_entries(entries: &[OperationEntry]) -> Self {
62        let mut path_to_id = HashMap::with_capacity(entries.len());
63        let mut path_to_category = HashMap::with_capacity(entries.len());
64        for entry in entries {
65            if let Some(ref path) = entry.path {
66                path_to_id.insert(path.clone(), entry.id.clone());
67                if !entry.category.is_empty() {
68                    path_to_category.insert(path.clone(), entry.category.clone());
69                }
70            }
71        }
72        Self {
73            path_to_id,
74            path_to_category,
75        }
76    }
77
78    pub fn lookup(&self, path: &str) -> Option<&str> {
79        self.path_to_id.get(path).map(|s| s.as_str())
80    }
81
82    /// Look up the declared category for a path (e.g., "read", "write", "delete", "admin").
83    /// Returns `None` if the path has no registry entry or no category declared.
84    pub fn lookup_category(&self, path: &str) -> Option<&str> {
85        self.path_to_category.get(path).map(|s| s.as_str())
86    }
87
88    pub fn is_empty(&self) -> bool {
89        self.path_to_id.is_empty()
90    }
91}
92
/// Configuration for Code Mode.
///
/// Every field carries a serde default, so a `[code_mode]` section may list
/// any subset of keys. Omitted keys take the same values that
/// `CodeModeConfig::default()` produces. SQL fields additionally accept an
/// unprefixed alias (see the SQL section below).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeModeConfig {
    /// Whether Code Mode is enabled for this server
    #[serde(default)]
    pub enabled: bool,

    // ========================================================================
    // GraphQL-specific settings
    // ========================================================================
    /// Whether to allow mutations (MVP: false)
    #[serde(default)]
    pub allow_mutations: bool,

    /// Allowed mutation names (whitelist). If empty and allow_mutations=true, all are allowed.
    #[serde(default)]
    pub allowed_mutations: HashSet<String>,

    /// Blocked mutation names (blacklist). Always blocked even if allow_mutations=true.
    #[serde(default)]
    pub blocked_mutations: HashSet<String>,

    /// Whether to allow introspection queries
    #[serde(default)]
    pub allow_introspection: bool,

    /// Fields that should never be returned (Type.field format) - GraphQL
    #[serde(default)]
    pub blocked_fields: HashSet<String>,

    /// Allowed query names (whitelist). If empty and mode is allowlist, none are allowed.
    #[serde(default)]
    pub allowed_queries: HashSet<String>,

    /// Blocked query names (blocklist). Always blocked even if reads enabled.
    #[serde(default)]
    pub blocked_queries: HashSet<String>,

    // ========================================================================
    // OpenAPI-specific settings
    // ========================================================================
    /// Whether read operations (GET) are enabled (default: true)
    #[serde(default = "default_true")]
    pub openapi_reads_enabled: bool,

    /// Whether write operations (POST, PUT, PATCH) are allowed globally
    #[serde(default)]
    pub openapi_allow_writes: bool,

    /// Allowed write operations (operationId or "METHOD /path")
    #[serde(default)]
    pub openapi_allowed_writes: HashSet<String>,

    /// Blocked write operations
    #[serde(default)]
    pub openapi_blocked_writes: HashSet<String>,

    /// Whether delete operations (DELETE) are allowed globally
    #[serde(default)]
    pub openapi_allow_deletes: bool,

    /// Allowed delete operations (operationId or "METHOD /path")
    #[serde(default)]
    pub openapi_allowed_deletes: HashSet<String>,

    /// Blocked paths (glob patterns like "/admin/*").
    /// Also forwarded as the sensitive path patterns of the OpenAPI policy entity.
    #[serde(default)]
    pub openapi_blocked_paths: HashSet<String>,

    /// Fields that are stripped from API responses entirely (no access)
    #[serde(default)]
    pub openapi_internal_blocked_fields: HashSet<String>,

    /// Fields that can be used internally but not in script output
    #[serde(default)]
    pub openapi_output_blocked_fields: HashSet<String>,

    /// Whether scripts must declare their return type with @returns
    #[serde(default)]
    pub openapi_require_output_declaration: bool,

    // ========================================================================
    // SQL-specific settings
    //
    // SQL fields accept both their prefixed name (`sql_allow_writes`) and the
    // unprefixed natural form (`allow_writes`). Downstream SQL servers can use
    // the unprefixed names in their `[code_mode]` block without a manual
    // conversion layer:
    //
    //     [code_mode]
    //     reads_enabled = true    # same as sql_reads_enabled
    //     allow_writes = false    # same as sql_allow_writes
    //     blocked_tables = ["secrets"]
    //     max_rows = 5000
    // ========================================================================
    /// Whether SELECT statements are enabled (default: true).
    #[serde(default = "default_true", alias = "reads_enabled")]
    pub sql_reads_enabled: bool,

    /// Whether INSERT/UPDATE/MERGE statements are allowed globally.
    #[serde(default, alias = "allow_writes")]
    pub sql_allow_writes: bool,

    /// Whether DELETE/TRUNCATE statements are allowed globally.
    #[serde(default, alias = "allow_deletes")]
    pub sql_allow_deletes: bool,

    /// Whether DDL (CREATE/ALTER/DROP/GRANT/REVOKE) is allowed globally.
    /// Default is `false` — DDL is almost never appropriate for LLM-generated code.
    #[serde(default, alias = "allow_ddl")]
    pub sql_allow_ddl: bool,

    /// Allowed statement types ("SELECT"/"INSERT"/"UPDATE"/"DELETE"/"DDL").
    /// If non-empty, only statement types in this set are allowed.
    #[serde(default, alias = "allowed_statements")]
    pub sql_allowed_statements: HashSet<String>,

    /// Blocked statement types. Always blocked even if globally allowed.
    #[serde(default, alias = "blocked_statements")]
    pub sql_blocked_statements: HashSet<String>,

    /// Tables that are always forbidden (blocklist mode).
    #[serde(default, alias = "blocked_tables")]
    pub sql_blocked_tables: HashSet<String>,

    /// If non-empty, only these tables can be accessed (allowlist mode).
    #[serde(default, alias = "allowed_tables")]
    pub sql_allowed_tables: HashSet<String>,

    /// Columns that may not be referenced in any statement (e.g., `password`, `ssn`).
    #[serde(default, alias = "blocked_columns")]
    pub sql_blocked_columns: HashSet<String>,

    /// Maximum row-count estimate allowed (based on LIMIT or default estimate).
    /// Default: 10_000.
    #[serde(default = "default_sql_max_rows", alias = "max_rows")]
    pub sql_max_rows: u64,

    /// Maximum number of JOINs in a single statement. Default: 5.
    #[serde(default = "default_sql_max_joins", alias = "max_joins")]
    pub sql_max_joins: u32,

    /// Whether to require a WHERE clause for UPDATE/DELETE statements (default: true).
    #[serde(default = "default_true", alias = "require_where_on_writes")]
    pub sql_require_where_on_writes: bool,

    /// Whether read-only (SELECT-class) statements MUST declare a LIMIT.
    /// Opt-in safety guard; default false (no behavior change for configs
    /// that omit it). Enforced in check_sql_config_authorization.
    #[serde(default, alias = "require_limit")]
    pub sql_require_limit: bool,

    // ========================================================================
    // Common settings
    // ========================================================================
    /// Action tags to override inferred actions for specific operations.
    #[serde(default)]
    pub action_tags: HashMap<String, String>,

    /// Maximum query depth (default: 10).
    /// Also used as the OpenAPI script nesting-depth limit.
    #[serde(default = "default_max_depth")]
    pub max_depth: u32,

    /// Maximum field count per query (default: 100).
    #[serde(default = "default_max_field_count")]
    pub max_field_count: u32,

    /// Maximum estimated query cost (default: 1000).
    #[serde(default = "default_max_cost")]
    pub max_cost: u32,

    /// Allowed sensitive data categories
    #[serde(default)]
    pub allowed_sensitive_categories: HashSet<String>,

    /// Token time-to-live in seconds (default: 300, i.e. 5 minutes).
    #[serde(default = "default_token_ttl")]
    pub token_ttl_seconds: i64,

    /// Risk levels that can be auto-approved without human confirmation
    /// (default: only `RiskLevel::Low`).
    #[serde(default = "default_auto_approve_levels")]
    pub auto_approve_levels: Vec<RiskLevel>,

    /// Maximum query length in characters (default: 10000).
    /// Also used as the OpenAPI script length limit.
    #[serde(default = "default_max_query_length")]
    pub max_query_length: usize,

    /// Maximum result rows to return (default: 10000).
    #[serde(default = "default_max_result_rows")]
    pub max_result_rows: usize,

    /// Query execution timeout in seconds (default: 30).
    /// Also used as the OpenAPI script execution timeout.
    #[serde(default = "default_query_timeout")]
    pub query_timeout_seconds: u32,

    /// Server ID for token generation.
    /// When unset, `CodeModeConfig::resolve_server_id` can fill it from the
    /// environment; `server_id()` reports `"unknown"` while it is `None`.
    #[serde(default)]
    pub server_id: Option<String>,

    // ========================================================================
    // SDK-backed settings
    // ========================================================================
    /// Allowed SDK operation names for SDK-backed Code Mode.
    /// When non-empty, Code Mode uses SDK dispatch instead of HTTP.
    /// Operations are validated at compile time — unlisted names are rejected.
    #[serde(default)]
    pub sdk_operations: HashSet<String>,

    /// Declared operations for plain-name ID mapping in Cedar entities.
    /// Parsed from [[code_mode.operations]] TOML sections.
    /// When non-empty, ScriptEntity calledOperations uses IDs from the registry
    /// built from these entries. Unregistered paths fall back to METHOD:/path.
    #[serde(default)]
    pub operations: Vec<OperationEntry>,
}
307
308impl Default for CodeModeConfig {
309    fn default() -> Self {
310        Self {
311            enabled: false,
312            // GraphQL
313            allow_mutations: false,
314            allowed_mutations: HashSet::new(),
315            blocked_mutations: HashSet::new(),
316            allow_introspection: false,
317            blocked_fields: HashSet::new(),
318            allowed_queries: HashSet::new(),
319            blocked_queries: HashSet::new(),
320            // OpenAPI
321            openapi_reads_enabled: true,
322            openapi_allow_writes: false,
323            openapi_allowed_writes: HashSet::new(),
324            openapi_blocked_writes: HashSet::new(),
325            openapi_allow_deletes: false,
326            openapi_allowed_deletes: HashSet::new(),
327            openapi_blocked_paths: HashSet::new(),
328            openapi_internal_blocked_fields: HashSet::new(),
329            openapi_output_blocked_fields: HashSet::new(),
330            openapi_require_output_declaration: false,
331            // SQL
332            sql_reads_enabled: true,
333            sql_allow_writes: false,
334            sql_allow_deletes: false,
335            sql_allow_ddl: false,
336            sql_allowed_statements: HashSet::new(),
337            sql_blocked_statements: HashSet::new(),
338            sql_blocked_tables: HashSet::new(),
339            sql_allowed_tables: HashSet::new(),
340            sql_blocked_columns: HashSet::new(),
341            sql_max_rows: default_sql_max_rows(),
342            sql_max_joins: default_sql_max_joins(),
343            sql_require_where_on_writes: true,
344            sql_require_limit: false,
345            // Common
346            action_tags: HashMap::new(),
347            max_depth: default_max_depth(),
348            max_field_count: default_max_field_count(),
349            max_cost: default_max_cost(),
350            allowed_sensitive_categories: HashSet::new(),
351            token_ttl_seconds: default_token_ttl(),
352            auto_approve_levels: default_auto_approve_levels(),
353            max_query_length: default_max_query_length(),
354            max_result_rows: default_max_result_rows(),
355            query_timeout_seconds: default_query_timeout(),
356            server_id: None,
357            // SDK
358            sdk_operations: HashSet::new(),
359            operations: Vec::new(),
360        }
361    }
362}
363
/// Wrapper for deserializing the `[code_mode]` section from a full TOML config file.
/// The file may contain other sections (`[server]`, `[[tools]]`, etc.) which are ignored.
///
/// Used by `CodeModeConfig::from_toml`. The other sections are ignored because
/// `deny_unknown_fields` is not set. Because `code_mode` is `#[serde(default)]`,
/// a file with no `[code_mode]` table yields `CodeModeConfig::default()`
/// instead of an error.
#[derive(Deserialize)]
struct TomlWrapper {
    #[serde(default)]
    code_mode: CodeModeConfig,
}
371
372impl CodeModeConfig {
373    /// Parse `CodeModeConfig` from a full TOML config string.
374    ///
375    /// Extracts the `[code_mode]` section (including `[[code_mode.operations]]`)
376    /// and ignores all other sections. This is the recommended way for external
377    /// servers to build their config from `config.toml`:
378    ///
379    /// ```rust,ignore
380    /// const CONFIG_TOML: &str = include_str!("../../config.toml");
381    ///
382    /// let config = CodeModeConfig::from_toml(CONFIG_TOML)
383    ///     .expect("Invalid code_mode section in config.toml");
384    /// ```
385    ///
386    /// If the TOML has no `[code_mode]` section, returns `CodeModeConfig::default()`.
387    pub fn from_toml(toml_str: &str) -> Result<Self, toml::de::Error> {
388        let wrapper: TomlWrapper = toml::from_str(toml_str)?;
389        Ok(wrapper.code_mode)
390    }
391
392    /// Create a new config with Code Mode enabled.
393    pub fn enabled() -> Self {
394        Self {
395            enabled: true,
396            ..Default::default()
397        }
398    }
399
400    /// Returns true if this config enables SDK-backed Code Mode.
401    pub fn is_sdk_mode(&self) -> bool {
402        !self.sdk_operations.is_empty()
403    }
404
405    /// Check if a risk level should be auto-approved.
406    pub fn should_auto_approve(&self, risk_level: RiskLevel) -> bool {
407        self.auto_approve_levels.contains(&risk_level)
408    }
409
410    /// Get the server ID, falling back to a default.
411    ///
412    /// **Note:** The `"unknown"` fallback produces silent AVP default-deny failures
413    /// (no Cedar policy matches a server_id of `"unknown"`). Prefer
414    /// [`resolve_server_id`](Self::resolve_server_id) to auto-fill from environment,
415    /// or [`require_server_id`](Self::require_server_id) to fail fast.
416    pub fn server_id(&self) -> &str {
417        self.server_id.as_deref().unwrap_or("unknown")
418    }
419
420    /// Auto-resolve `server_id` from environment if not already set.
421    ///
422    /// Resolution order:
423    /// 1. `self.server_id` (if already set, e.g., from TOML) — no change
424    /// 2. `PMCP_SERVER_ID` env var
425    /// 3. `AWS_LAMBDA_FUNCTION_NAME` env var (Lambda runtime)
426    /// 4. Left as `None` — caller is responsible for handling
427    ///
428    /// [`ValidationPipeline`](crate::ValidationPipeline) constructors call this
429    /// automatically, so wrappers rarely need to invoke it directly.
430    pub fn resolve_server_id(&mut self) {
431        if self.server_id.is_some() {
432            return;
433        }
434        self.server_id = resolve_server_id_from_env();
435    }
436
437    /// Return the `server_id`, or an error if not resolved.
438    ///
439    /// Use this in production code paths that require AVP authorization —
440    /// it fails fast with a clear message instead of letting `"unknown"`
441    /// reach AVP and produce a silent default-deny.
442    pub fn require_server_id(&self) -> Result<&str, ValidationError> {
443        self.server_id.as_deref().ok_or_else(|| {
444            ValidationError::ConfigError(
445                "server_id is not set. Set it in config.toml, PMCP_SERVER_ID env var, \
446                 or AWS_LAMBDA_FUNCTION_NAME (Lambda). Without it, AVP authorization \
447                 will default-deny silently."
448                    .into(),
449            )
450        })
451    }
452
453    /// Convert to ServerConfigEntity for policy evaluation.
454    pub fn to_server_config_entity(&self) -> crate::policy::ServerConfigEntity {
455        crate::policy::ServerConfigEntity {
456            server_id: self.server_id().to_string(),
457            server_type: "graphql".to_string(),
458            allow_write: self.allow_mutations,
459            allow_delete: self.allow_mutations,
460            allow_admin: self.allow_introspection,
461            allowed_operations: self.allowed_mutations.clone(),
462            blocked_operations: self.blocked_mutations.clone(),
463            max_depth: self.max_depth,
464            max_field_count: self.max_field_count,
465            max_cost: self.max_cost,
466            max_api_calls: 50,
467            blocked_fields: self.blocked_fields.clone(),
468            allowed_sensitive_categories: self.allowed_sensitive_categories.clone(),
469        }
470    }
471
472    /// Convert to OpenAPIServerEntity for policy evaluation (OpenAPI Code Mode).
473    #[cfg(feature = "openapi-code-mode")]
474    pub fn to_openapi_server_entity(&self) -> crate::policy::OpenAPIServerEntity {
475        let mut allowed_operations = self.openapi_allowed_writes.clone();
476        allowed_operations.extend(self.openapi_allowed_deletes.clone());
477
478        let write_mode = if !self.openapi_allow_writes {
479            "deny_all"
480        } else if !self.openapi_allowed_writes.is_empty() {
481            "allowlist"
482        } else if !self.openapi_blocked_writes.is_empty() {
483            "blocklist"
484        } else {
485            "allow_all"
486        };
487
488        crate::policy::OpenAPIServerEntity {
489            server_id: self.server_id().to_string(),
490            server_type: "openapi".to_string(),
491            allow_write: self.openapi_allow_writes,
492            allow_delete: self.openapi_allow_deletes,
493            allow_admin: false,
494            write_mode: write_mode.to_string(),
495            max_depth: self.max_depth,
496            max_cost: self.max_cost,
497            max_api_calls: 50,
498            max_loop_iterations: 100,
499            max_script_length: self.max_query_length as u32,
500            max_nesting_depth: self.max_depth,
501            execution_timeout_seconds: self.query_timeout_seconds,
502            allowed_operations,
503            blocked_operations: self.openapi_blocked_writes.clone(),
504            allowed_methods: HashSet::new(),
505            blocked_methods: HashSet::new(),
506            allowed_path_patterns: HashSet::new(),
507            blocked_path_patterns: self.openapi_blocked_paths.clone(),
508            sensitive_path_patterns: self.openapi_blocked_paths.clone(),
509            auto_approve_read_only: self.openapi_reads_enabled,
510            max_api_calls_for_auto_approve: 10,
511            internal_blocked_fields: self.openapi_internal_blocked_fields.clone(),
512            output_blocked_fields: self.openapi_output_blocked_fields.clone(),
513            require_output_declaration: self.openapi_require_output_declaration,
514        }
515    }
516
517    /// Convert to `SqlServerEntity` for policy evaluation (SQL Code Mode).
518    #[cfg(feature = "sql-code-mode")]
519    pub fn to_sql_server_entity(&self) -> crate::policy::SqlServerEntity {
520        crate::policy::SqlServerEntity {
521            server_id: self.server_id().to_string(),
522            server_type: "sql".to_string(),
523            allow_write: self.sql_allow_writes,
524            allow_delete: self.sql_allow_deletes,
525            allow_admin: self.sql_allow_ddl,
526            max_rows: self.sql_max_rows,
527            max_joins: self.sql_max_joins,
528            allowed_operations: self.sql_allowed_statements.clone(),
529            blocked_operations: self.sql_blocked_statements.clone(),
530            blocked_tables: self.sql_blocked_tables.clone(),
531            blocked_columns: self.sql_blocked_columns.clone(),
532            allowed_tables: self.sql_allowed_tables.clone(),
533        }
534    }
535}
536
/// Serde default for boolean flags that stay on unless a config turns them off.
fn default_true() -> bool {
    true
}

/// Default token lifetime: five minutes, expressed in seconds.
fn default_token_ttl() -> i64 {
    const SECONDS_PER_MINUTE: i64 = 60;
    5 * SECONDS_PER_MINUTE
}
544
545fn default_auto_approve_levels() -> Vec<RiskLevel> {
546    vec![RiskLevel::Low]
547}
548
/// Default cap on query/script length, in characters.
fn default_max_query_length() -> usize {
    10_000
}

/// Default cap on rows returned to the caller.
fn default_max_result_rows() -> usize {
    10_000
}

/// Default query execution timeout, in seconds.
fn default_query_timeout() -> u32 {
    30
}

/// Default maximum query depth.
fn default_max_depth() -> u32 {
    10
}

/// Default maximum field count per query.
fn default_max_field_count() -> u32 {
    100
}

/// Default maximum estimated query cost.
fn default_max_cost() -> u32 {
    1_000
}

/// Default SQL row-count estimate ceiling.
fn default_sql_max_rows() -> u64 {
    10_000
}

/// Default maximum number of JOINs per SQL statement.
fn default_sql_max_joins() -> u32 {
    5
}
580
581#[cfg(test)]
582mod tests {
583    use super::*;
584
585    #[test]
586    fn test_default_config() {
587        let config = CodeModeConfig::default();
588        assert!(!config.enabled);
589        assert!(!config.allow_mutations);
590        assert_eq!(config.token_ttl_seconds, 300);
591        assert_eq!(config.auto_approve_levels, vec![RiskLevel::Low]);
592    }
593
594    #[test]
595    fn test_enabled_config() {
596        let config = CodeModeConfig::enabled();
597        assert!(config.enabled);
598    }
599
600    #[test]
601    fn test_auto_approve() {
602        let config = CodeModeConfig::default();
603        assert!(config.should_auto_approve(RiskLevel::Low));
604        assert!(!config.should_auto_approve(RiskLevel::Medium));
605        assert!(!config.should_auto_approve(RiskLevel::High));
606        assert!(!config.should_auto_approve(RiskLevel::Critical));
607    }
608
609    #[test]
610    fn test_operation_registry_from_entries() {
611        let entries = vec![
612            OperationEntry {
613                id: "getCostAnomalies".to_string(),
614                category: "read".to_string(),
615                description: "Get cost anomalies".to_string(),
616                path: Some("/getCostAnomalies".to_string()),
617            },
618            OperationEntry {
619                id: "listInstances".to_string(),
620                category: "read".to_string(),
621                description: "List EC2 instances".to_string(),
622                path: Some("/listInstances".to_string()),
623            },
624        ];
625        let registry = OperationRegistry::from_entries(&entries);
626        assert_eq!(
627            registry.lookup("/getCostAnomalies"),
628            Some("getCostAnomalies")
629        );
630        assert_eq!(registry.lookup("/listInstances"), Some("listInstances"));
631    }
632
633    #[test]
634    fn test_operation_registry_lookup_unregistered() {
635        let entries = vec![OperationEntry {
636            id: "getCostAnomalies".to_string(),
637            category: "read".to_string(),
638            description: String::new(),
639            path: Some("/getCostAnomalies".to_string()),
640        }];
641        let registry = OperationRegistry::from_entries(&entries);
642        assert_eq!(registry.lookup("/unknownPath"), None);
643        assert_eq!(registry.lookup(""), None);
644    }
645
646    #[test]
647    fn test_operation_registry_lookup_category() {
648        let entries = vec![
649            OperationEntry {
650                id: "getCostAnomalies".to_string(),
651                category: "read".to_string(),
652                description: String::new(),
653                path: Some("/getCostAnomalies".to_string()),
654            },
655            OperationEntry {
656                id: "deleteReservation".to_string(),
657                category: "delete".to_string(),
658                description: String::new(),
659                path: Some("/deleteReservation".to_string()),
660            },
661            OperationEntry {
662                id: "updateBudget".to_string(),
663                category: "write".to_string(),
664                description: String::new(),
665                path: Some("/updateBudget".to_string()),
666            },
667        ];
668        let registry = OperationRegistry::from_entries(&entries);
669        assert_eq!(registry.lookup_category("/getCostAnomalies"), Some("read"));
670        assert_eq!(
671            registry.lookup_category("/deleteReservation"),
672            Some("delete")
673        );
674        assert_eq!(registry.lookup_category("/updateBudget"), Some("write"));
675        assert_eq!(registry.lookup_category("/unknownPath"), None);
676    }
677
678    #[test]
679    fn test_operation_registry_empty_category_excluded() {
680        let entries = vec![OperationEntry {
681            id: "legacyOp".to_string(),
682            category: String::new(), // empty = not declared
683            description: String::new(),
684            path: Some("/legacyOp".to_string()),
685        }];
686        let registry = OperationRegistry::from_entries(&entries);
687        // ID lookup still works
688        assert_eq!(registry.lookup("/legacyOp"), Some("legacyOp"));
689        // Category lookup returns None for empty category
690        assert_eq!(registry.lookup_category("/legacyOp"), None);
691    }
692
693    #[test]
694    fn test_operation_registry_is_empty() {
695        let empty_registry = OperationRegistry::from_entries(&[]);
696        assert!(empty_registry.is_empty());
697
698        let entries = vec![OperationEntry {
699            id: "op1".to_string(),
700            category: "read".to_string(),
701            description: String::new(),
702            path: Some("/op1".to_string()),
703        }];
704        let registry = OperationRegistry::from_entries(&entries);
705        assert!(!registry.is_empty());
706    }
707
708    #[test]
709    fn test_operation_entry_deserialization() {
710        let toml_str = r#"
711id = "getCostAnomalies"
712category = "read"
713description = "Get cost anomalies"
714path = "/getCostAnomalies"
715"#;
716        let entry: OperationEntry =
717            toml::from_str(toml_str).expect("Failed to deserialize OperationEntry");
718        assert_eq!(entry.id, "getCostAnomalies");
719        assert_eq!(entry.category, "read");
720        assert_eq!(entry.description, "Get cost anomalies");
721        assert_eq!(entry.path, Some("/getCostAnomalies".to_string()));
722    }
723
724    #[test]
725    fn test_code_mode_config_with_operations() {
726        let toml_str = r#"
727enabled = true
728
729[[operations]]
730id = "getCostAnomalies"
731category = "read"
732description = "Get cost anomalies"
733path = "/getCostAnomalies"
734
735[[operations]]
736id = "listInstances"
737category = "read"
738path = "/listInstances"
739"#;
740        let config: CodeModeConfig = toml::from_str(toml_str).expect("Failed to deserialize");
741        assert!(config.enabled);
742        assert_eq!(config.operations.len(), 2);
743        assert_eq!(config.operations[0].id, "getCostAnomalies");
744        assert_eq!(config.operations[1].id, "listInstances");
745    }
746
747    #[test]
748    fn test_code_mode_config_without_operations_defaults_to_empty() {
749        let toml_str = r#"
750enabled = true
751"#;
752        let config: CodeModeConfig = toml::from_str(toml_str).expect("Failed to deserialize");
753        assert!(config.enabled);
754        assert!(config.operations.is_empty());
755    }
756
757    #[test]
758    fn test_from_toml_extracts_code_mode_section() {
759        let toml_str = r#"
760[server]
761name = "cost-coach"
762type = "openapi-api"
763
764[code_mode]
765enabled = true
766token_ttl_seconds = 600
767server_id = "cost-coach"
768
769[[code_mode.operations]]
770id = "getCostAndUsage"
771category = "read"
772description = "Historical cost and usage data"
773path = "/getCostAndUsage"
774
775[[code_mode.operations]]
776id = "getCostAnomalies"
777category = "read"
778description = "Cost anomalies detected by AWS"
779path = "/getCostAnomalies"
780
781[[tools]]
782name = "some_tool"
783"#;
784        let config = CodeModeConfig::from_toml(toml_str).expect("Failed to parse");
785        assert!(config.enabled);
786        assert_eq!(config.token_ttl_seconds, 600);
787        assert_eq!(config.server_id, Some("cost-coach".to_string()));
788        assert_eq!(config.operations.len(), 2);
789        assert_eq!(config.operations[0].id, "getCostAndUsage");
790        assert_eq!(config.operations[1].id, "getCostAnomalies");
791        assert_eq!(
792            config.operations[0].path,
793            Some("/getCostAndUsage".to_string())
794        );
795    }
796
797    #[test]
798    fn test_from_toml_missing_code_mode_returns_default() {
799        let toml_str = r#"
800[server]
801name = "some-server"
802"#;
803        let config = CodeModeConfig::from_toml(toml_str).expect("Failed to parse");
804        assert!(!config.enabled);
805        assert!(config.operations.is_empty());
806        assert_eq!(config.token_ttl_seconds, 300); // default
807    }
808
809    // =========================================================================
810    // server_id resolution tests
811    //
812    // These tests mutate process-wide env vars. Cargo parallelizes tests across
813    // threads in the same process, so a shared Mutex serializes them — without
814    // this, set_var/remove_var in one test would race with another.
815    // =========================================================================
816
817    use std::sync::Mutex;
818    static ENV_LOCK: Mutex<()> = Mutex::new(());
819
820    struct EnvGuard {
821        _lock: std::sync::MutexGuard<'static, ()>,
822    }
823
824    impl EnvGuard {
825        fn acquire() -> Self {
826            let lock = ENV_LOCK
827                .lock()
828                .unwrap_or_else(|poisoned| poisoned.into_inner());
829            std::env::remove_var("PMCP_SERVER_ID");
830            std::env::remove_var("AWS_LAMBDA_FUNCTION_NAME");
831            Self { _lock: lock }
832        }
833    }
834
835    impl Drop for EnvGuard {
836        fn drop(&mut self) {
837            std::env::remove_var("PMCP_SERVER_ID");
838            std::env::remove_var("AWS_LAMBDA_FUNCTION_NAME");
839        }
840    }
841
842    #[test]
843    fn resolve_server_id_from_explicit_config_takes_precedence() {
844        let _g = EnvGuard::acquire();
845        std::env::set_var("PMCP_SERVER_ID", "from-env");
846
847        let mut config = CodeModeConfig {
848            server_id: Some("from-config".to_string()),
849            ..Default::default()
850        };
851        config.resolve_server_id();
852
853        assert_eq!(config.server_id.as_deref(), Some("from-config"));
854    }
855
856    #[test]
857    fn resolve_server_id_from_pmcp_env() {
858        let _g = EnvGuard::acquire();
859        std::env::set_var("PMCP_SERVER_ID", "my-server");
860
861        let mut config = CodeModeConfig::default();
862        config.resolve_server_id();
863
864        assert_eq!(config.server_id.as_deref(), Some("my-server"));
865    }
866
867    #[test]
868    fn resolve_server_id_from_lambda_env() {
869        let _g = EnvGuard::acquire();
870        std::env::set_var("AWS_LAMBDA_FUNCTION_NAME", "my-lambda-fn");
871
872        let mut config = CodeModeConfig::default();
873        config.resolve_server_id();
874
875        assert_eq!(config.server_id.as_deref(), Some("my-lambda-fn"));
876    }
877
878    #[test]
879    fn resolve_server_id_pmcp_wins_over_lambda() {
880        let _g = EnvGuard::acquire();
881        std::env::set_var("PMCP_SERVER_ID", "explicit");
882        std::env::set_var("AWS_LAMBDA_FUNCTION_NAME", "lambda-fn");
883
884        let mut config = CodeModeConfig::default();
885        config.resolve_server_id();
886
887        assert_eq!(config.server_id.as_deref(), Some("explicit"));
888    }
889
890    #[test]
891    fn resolve_server_id_leaves_none_when_unset() {
892        let _g = EnvGuard::acquire();
893        let mut config = CodeModeConfig::default();
894        config.resolve_server_id();
895        assert!(config.server_id.is_none());
896    }
897
898    #[test]
899    fn require_server_id_errors_when_unset() {
900        let config = CodeModeConfig::default();
901        let result = config.require_server_id();
902        assert!(matches!(result, Err(ValidationError::ConfigError(_))));
903    }
904
905    #[test]
906    fn require_server_id_returns_value_when_set() {
907        let config = CodeModeConfig {
908            server_id: Some("my-server".to_string()),
909            ..Default::default()
910        };
911        assert_eq!(config.require_server_id().unwrap(), "my-server");
912    }
913
914    #[test]
915    fn resolve_server_id_from_env_free_fn_treats_empty_as_unset() {
916        let _g = EnvGuard::acquire();
917        std::env::set_var("PMCP_SERVER_ID", "");
918        assert_eq!(resolve_server_id_from_env(), None);
919    }
920
921    // =========================================================================
922    // SQL TOML DX tests (serde aliases)
923    // =========================================================================
924
925    #[test]
926    fn sql_config_accepts_unprefixed_toml_names() {
927        let toml_str = r#"
928enabled = true
929allow_writes = true
930allow_deletes = true
931allow_ddl = true
932allowed_tables = ["users", "orders"]
933blocked_tables = ["secrets"]
934blocked_columns = ["password", "ssn"]
935max_rows = 5000
936max_joins = 3
937require_where_on_writes = false
938"#;
939        let config: CodeModeConfig =
940            toml::from_str(toml_str).expect("Failed to deserialize with unprefixed aliases");
941
942        assert!(config.enabled);
943        assert!(config.sql_allow_writes);
944        assert!(config.sql_allow_deletes);
945        assert!(config.sql_allow_ddl);
946        assert!(config.sql_allowed_tables.contains("users"));
947        assert!(config.sql_allowed_tables.contains("orders"));
948        assert!(config.sql_blocked_tables.contains("secrets"));
949        assert!(config.sql_blocked_columns.contains("password"));
950        assert_eq!(config.sql_max_rows, 5000);
951        assert_eq!(config.sql_max_joins, 3);
952        assert!(!config.sql_require_where_on_writes);
953    }
954
955    #[test]
956    fn sql_config_accepts_prefixed_toml_names() {
957        let toml_str = r#"
958enabled = true
959sql_allow_writes = true
960sql_blocked_tables = ["secrets"]
961sql_max_rows = 5000
962"#;
963        let config: CodeModeConfig =
964            toml::from_str(toml_str).expect("Failed to deserialize with prefixed names");
965
966        assert!(config.sql_allow_writes);
967        assert!(config.sql_blocked_tables.contains("secrets"));
968        assert_eq!(config.sql_max_rows, 5000);
969    }
970}