Skip to main content

pmcp_code_mode/
config.rs

1//! Code Mode configuration.
2
3use crate::types::RiskLevel;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::collections::HashSet;
7
8/// A single declared operation in Code Mode configuration.
9/// Maps a raw API path to a canonical plain-name ID for Cedar policies.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct OperationEntry {
12    /// Canonical operation ID (plain name, no method prefix).
13    /// This is what appears in Cedar policy calledOperations.
14    pub id: String,
15
16    /// Action category for AVP action routing.
17    /// Values: "read", "write", "delete", "admin"
18    pub category: String,
19
20    /// Human-readable description for admin UI and LLM context.
21    #[serde(default)]
22    pub description: String,
23
24    /// Raw API path this ID maps to (e.g., "/getCostAnomalies").
25    /// Used to match against api_call.path from JavaScript analysis.
26    #[serde(default)]
27    pub path: Option<String>,
28}
29
30/// Registry built from [[code_mode.operations]] config entries.
31/// Maps raw paths to canonical operation IDs and categories.
32#[derive(Debug, Clone, Default)]
33pub struct OperationRegistry {
34    path_to_id: HashMap<String, String>,
35    path_to_category: HashMap<String, String>,
36}
37
38impl OperationRegistry {
39    pub fn from_entries(entries: &[OperationEntry]) -> Self {
40        let mut path_to_id = HashMap::with_capacity(entries.len());
41        let mut path_to_category = HashMap::with_capacity(entries.len());
42        for entry in entries {
43            if let Some(ref path) = entry.path {
44                path_to_id.insert(path.clone(), entry.id.clone());
45                if !entry.category.is_empty() {
46                    path_to_category.insert(path.clone(), entry.category.clone());
47                }
48            }
49        }
50        Self {
51            path_to_id,
52            path_to_category,
53        }
54    }
55
56    pub fn lookup(&self, path: &str) -> Option<&str> {
57        self.path_to_id.get(path).map(|s| s.as_str())
58    }
59
60    /// Look up the declared category for a path (e.g., "read", "write", "delete", "admin").
61    /// Returns `None` if the path has no registry entry or no category declared.
62    pub fn lookup_category(&self, path: &str) -> Option<&str> {
63        self.path_to_category.get(path).map(|s| s.as_str())
64    }
65
66    pub fn is_empty(&self) -> bool {
67        self.path_to_id.is_empty()
68    }
69}
70
71/// Configuration for Code Mode.
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct CodeModeConfig {
74    /// Whether Code Mode is enabled for this server
75    #[serde(default)]
76    pub enabled: bool,
77
78    // ========================================================================
79    // GraphQL-specific settings
80    // ========================================================================
81    /// Whether to allow mutations (MVP: false)
82    #[serde(default)]
83    pub allow_mutations: bool,
84
85    /// Allowed mutation names (whitelist). If empty and allow_mutations=true, all are allowed.
86    #[serde(default)]
87    pub allowed_mutations: HashSet<String>,
88
89    /// Blocked mutation names (blacklist). Always blocked even if allow_mutations=true.
90    #[serde(default)]
91    pub blocked_mutations: HashSet<String>,
92
93    /// Whether to allow introspection queries
94    #[serde(default)]
95    pub allow_introspection: bool,
96
97    /// Fields that should never be returned (Type.field format) - GraphQL
98    #[serde(default)]
99    pub blocked_fields: HashSet<String>,
100
101    /// Allowed query names (whitelist). If empty and mode is allowlist, none are allowed.
102    #[serde(default)]
103    pub allowed_queries: HashSet<String>,
104
105    /// Blocked query names (blocklist). Always blocked even if reads enabled.
106    #[serde(default)]
107    pub blocked_queries: HashSet<String>,
108
109    // ========================================================================
110    // OpenAPI-specific settings
111    // ========================================================================
112    /// Whether read operations (GET) are enabled (default: true)
113    #[serde(default = "default_true")]
114    pub openapi_reads_enabled: bool,
115
116    /// Whether write operations (POST, PUT, PATCH) are allowed globally
117    #[serde(default)]
118    pub openapi_allow_writes: bool,
119
120    /// Allowed write operations (operationId or "METHOD /path")
121    #[serde(default)]
122    pub openapi_allowed_writes: HashSet<String>,
123
124    /// Blocked write operations
125    #[serde(default)]
126    pub openapi_blocked_writes: HashSet<String>,
127
128    /// Whether delete operations (DELETE) are allowed globally
129    #[serde(default)]
130    pub openapi_allow_deletes: bool,
131
132    /// Allowed delete operations (operationId or "METHOD /path")
133    #[serde(default)]
134    pub openapi_allowed_deletes: HashSet<String>,
135
136    /// Blocked paths (glob patterns like "/admin/*")
137    #[serde(default)]
138    pub openapi_blocked_paths: HashSet<String>,
139
140    /// Fields that are stripped from API responses entirely (no access)
141    #[serde(default)]
142    pub openapi_internal_blocked_fields: HashSet<String>,
143
144    /// Fields that can be used internally but not in script output
145    #[serde(default)]
146    pub openapi_output_blocked_fields: HashSet<String>,
147
148    /// Whether scripts must declare their return type with @returns
149    #[serde(default)]
150    pub openapi_require_output_declaration: bool,
151
152    // ========================================================================
153    // Common settings
154    // ========================================================================
155    /// Action tags to override inferred actions for specific operations.
156    #[serde(default)]
157    pub action_tags: HashMap<String, String>,
158
159    /// Maximum query depth
160    #[serde(default = "default_max_depth")]
161    pub max_depth: u32,
162
163    /// Maximum field count per query
164    #[serde(default = "default_max_field_count")]
165    pub max_field_count: u32,
166
167    /// Maximum estimated query cost
168    #[serde(default = "default_max_cost")]
169    pub max_cost: u32,
170
171    /// Allowed sensitive data categories
172    #[serde(default)]
173    pub allowed_sensitive_categories: HashSet<String>,
174
175    /// Token time-to-live in seconds
176    #[serde(default = "default_token_ttl")]
177    pub token_ttl_seconds: i64,
178
179    /// Risk levels that can be auto-approved without human confirmation
180    #[serde(default = "default_auto_approve_levels")]
181    pub auto_approve_levels: Vec<RiskLevel>,
182
183    /// Maximum query length in characters
184    #[serde(default = "default_max_query_length")]
185    pub max_query_length: usize,
186
187    /// Maximum result rows to return
188    #[serde(default = "default_max_result_rows")]
189    pub max_result_rows: usize,
190
191    /// Query execution timeout in seconds
192    #[serde(default = "default_query_timeout")]
193    pub query_timeout_seconds: u32,
194
195    /// Server ID for token generation
196    #[serde(default)]
197    pub server_id: Option<String>,
198
199    // ========================================================================
200    // SDK-backed settings
201    // ========================================================================
202    /// Allowed SDK operation names for SDK-backed Code Mode.
203    /// When non-empty, Code Mode uses SDK dispatch instead of HTTP.
204    /// Operations are validated at compile time — unlisted names are rejected.
205    #[serde(default)]
206    pub sdk_operations: HashSet<String>,
207
208    /// Declared operations for plain-name ID mapping in Cedar entities.
209    /// Parsed from [[code_mode.operations]] TOML sections.
210    /// When non-empty, ScriptEntity calledOperations uses IDs from the registry
211    /// built from these entries. Unregistered paths fall back to METHOD:/path.
212    #[serde(default)]
213    pub operations: Vec<OperationEntry>,
214}
215
216impl Default for CodeModeConfig {
217    fn default() -> Self {
218        Self {
219            enabled: false,
220            // GraphQL
221            allow_mutations: false,
222            allowed_mutations: HashSet::new(),
223            blocked_mutations: HashSet::new(),
224            allow_introspection: false,
225            blocked_fields: HashSet::new(),
226            allowed_queries: HashSet::new(),
227            blocked_queries: HashSet::new(),
228            // OpenAPI
229            openapi_reads_enabled: true,
230            openapi_allow_writes: false,
231            openapi_allowed_writes: HashSet::new(),
232            openapi_blocked_writes: HashSet::new(),
233            openapi_allow_deletes: false,
234            openapi_allowed_deletes: HashSet::new(),
235            openapi_blocked_paths: HashSet::new(),
236            openapi_internal_blocked_fields: HashSet::new(),
237            openapi_output_blocked_fields: HashSet::new(),
238            openapi_require_output_declaration: false,
239            // Common
240            action_tags: HashMap::new(),
241            max_depth: default_max_depth(),
242            max_field_count: default_max_field_count(),
243            max_cost: default_max_cost(),
244            allowed_sensitive_categories: HashSet::new(),
245            token_ttl_seconds: default_token_ttl(),
246            auto_approve_levels: default_auto_approve_levels(),
247            max_query_length: default_max_query_length(),
248            max_result_rows: default_max_result_rows(),
249            query_timeout_seconds: default_query_timeout(),
250            server_id: None,
251            // SDK
252            sdk_operations: HashSet::new(),
253            operations: Vec::new(),
254        }
255    }
256}
257
258/// Wrapper for deserializing the `[code_mode]` section from a full TOML config file.
259/// The file may contain other sections (`[server]`, `[[tools]]`, etc.) which are ignored.
260#[derive(Deserialize)]
261struct TomlWrapper {
262    #[serde(default)]
263    code_mode: CodeModeConfig,
264}
265
266impl CodeModeConfig {
267    /// Parse `CodeModeConfig` from a full TOML config string.
268    ///
269    /// Extracts the `[code_mode]` section (including `[[code_mode.operations]]`)
270    /// and ignores all other sections. This is the recommended way for external
271    /// servers to build their config from `config.toml`:
272    ///
273    /// ```rust,ignore
274    /// const CONFIG_TOML: &str = include_str!("../../config.toml");
275    ///
276    /// let config = CodeModeConfig::from_toml(CONFIG_TOML)
277    ///     .expect("Invalid code_mode section in config.toml");
278    /// ```
279    ///
280    /// If the TOML has no `[code_mode]` section, returns `CodeModeConfig::default()`.
281    pub fn from_toml(toml_str: &str) -> Result<Self, toml::de::Error> {
282        let wrapper: TomlWrapper = toml::from_str(toml_str)?;
283        Ok(wrapper.code_mode)
284    }
285
286    /// Create a new config with Code Mode enabled.
287    pub fn enabled() -> Self {
288        Self {
289            enabled: true,
290            ..Default::default()
291        }
292    }
293
294    /// Returns true if this config enables SDK-backed Code Mode.
295    pub fn is_sdk_mode(&self) -> bool {
296        !self.sdk_operations.is_empty()
297    }
298
299    /// Check if a risk level should be auto-approved.
300    pub fn should_auto_approve(&self, risk_level: RiskLevel) -> bool {
301        self.auto_approve_levels.contains(&risk_level)
302    }
303
304    /// Get the server ID, falling back to a default.
305    pub fn server_id(&self) -> &str {
306        self.server_id.as_deref().unwrap_or("unknown")
307    }
308
309    /// Convert to ServerConfigEntity for policy evaluation.
310    pub fn to_server_config_entity(&self) -> crate::policy::ServerConfigEntity {
311        crate::policy::ServerConfigEntity {
312            server_id: self.server_id().to_string(),
313            server_type: "graphql".to_string(),
314            allow_write: self.allow_mutations,
315            allow_delete: self.allow_mutations,
316            allow_admin: self.allow_introspection,
317            allowed_operations: self.allowed_mutations.clone(),
318            blocked_operations: self.blocked_mutations.clone(),
319            max_depth: self.max_depth,
320            max_field_count: self.max_field_count,
321            max_cost: self.max_cost,
322            max_api_calls: 50,
323            blocked_fields: self.blocked_fields.clone(),
324            allowed_sensitive_categories: self.allowed_sensitive_categories.clone(),
325        }
326    }
327
328    /// Convert to OpenAPIServerEntity for policy evaluation (OpenAPI Code Mode).
329    #[cfg(feature = "openapi-code-mode")]
330    pub fn to_openapi_server_entity(&self) -> crate::policy::OpenAPIServerEntity {
331        let mut allowed_operations = self.openapi_allowed_writes.clone();
332        allowed_operations.extend(self.openapi_allowed_deletes.clone());
333
334        let write_mode = if !self.openapi_allow_writes {
335            "deny_all"
336        } else if !self.openapi_allowed_writes.is_empty() {
337            "allowlist"
338        } else if !self.openapi_blocked_writes.is_empty() {
339            "blocklist"
340        } else {
341            "allow_all"
342        };
343
344        crate::policy::OpenAPIServerEntity {
345            server_id: self.server_id().to_string(),
346            server_type: "openapi".to_string(),
347            allow_write: self.openapi_allow_writes,
348            allow_delete: self.openapi_allow_deletes,
349            allow_admin: false,
350            write_mode: write_mode.to_string(),
351            max_depth: self.max_depth,
352            max_cost: self.max_cost,
353            max_api_calls: 50,
354            max_loop_iterations: 100,
355            max_script_length: self.max_query_length as u32,
356            max_nesting_depth: self.max_depth,
357            execution_timeout_seconds: self.query_timeout_seconds,
358            allowed_operations,
359            blocked_operations: self.openapi_blocked_writes.clone(),
360            allowed_methods: HashSet::new(),
361            blocked_methods: HashSet::new(),
362            allowed_path_patterns: HashSet::new(),
363            blocked_path_patterns: self.openapi_blocked_paths.clone(),
364            sensitive_path_patterns: self.openapi_blocked_paths.clone(),
365            auto_approve_read_only: self.openapi_reads_enabled,
366            max_api_calls_for_auto_approve: 10,
367            internal_blocked_fields: self.openapi_internal_blocked_fields.clone(),
368            output_blocked_fields: self.openapi_output_blocked_fields.clone(),
369            require_output_declaration: self.openapi_require_output_declaration,
370        }
371    }
372}
373
/// Serde default provider for boolean settings that are on unless configured otherwise.
fn default_true() -> bool {
    true
}
377
/// Default token time-to-live in seconds (five minutes).
fn default_token_ttl() -> i64 {
    const FIVE_MINUTES_IN_SECONDS: i64 = 5 * 60;
    FIVE_MINUTES_IN_SECONDS
}
381
382fn default_auto_approve_levels() -> Vec<RiskLevel> {
383    vec![RiskLevel::Low]
384}
385
/// Default maximum query length.
fn default_max_query_length() -> usize {
    10_000
}
389
/// Default maximum number of result rows.
fn default_max_result_rows() -> usize {
    10_000
}
393
/// Default query execution timeout in seconds.
fn default_query_timeout() -> u32 {
    const HALF_MINUTE_IN_SECONDS: u32 = 30;
    HALF_MINUTE_IN_SECONDS
}
397
/// Default maximum query depth.
fn default_max_depth() -> u32 {
    10_u32
}
401
/// Default maximum number of fields per query.
fn default_max_field_count() -> u32 {
    100_u32
}
405
/// Default maximum estimated query cost.
fn default_max_cost() -> u32 {
    1_000
}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `OperationEntry` whose `path` is always present.
    fn op(id: &str, category: &str, description: &str, path: &str) -> OperationEntry {
        OperationEntry {
            id: id.to_string(),
            category: category.to_string(),
            description: description.to_string(),
            path: Some(path.to_string()),
        }
    }

    // ---- CodeModeConfig defaults -------------------------------------------

    #[test]
    fn test_default_config() {
        let defaults = CodeModeConfig::default();
        assert!(!defaults.enabled);
        assert!(!defaults.allow_mutations);
        assert_eq!(defaults.token_ttl_seconds, 300);
        assert_eq!(defaults.auto_approve_levels, vec![RiskLevel::Low]);
    }

    #[test]
    fn test_enabled_config() {
        assert!(CodeModeConfig::enabled().enabled);
    }

    #[test]
    fn test_auto_approve() {
        let defaults = CodeModeConfig::default();
        // Only the lowest risk level is approved automatically by default.
        assert!(defaults.should_auto_approve(RiskLevel::Low));
        assert!(!defaults.should_auto_approve(RiskLevel::Medium));
        assert!(!defaults.should_auto_approve(RiskLevel::High));
        assert!(!defaults.should_auto_approve(RiskLevel::Critical));
    }

    // ---- OperationRegistry --------------------------------------------------

    #[test]
    fn test_operation_registry_from_entries() {
        let declared = vec![
            op(
                "getCostAnomalies",
                "read",
                "Get cost anomalies",
                "/getCostAnomalies",
            ),
            op(
                "listInstances",
                "read",
                "List EC2 instances",
                "/listInstances",
            ),
        ];
        let registry = OperationRegistry::from_entries(&declared);
        assert_eq!(
            registry.lookup("/getCostAnomalies"),
            Some("getCostAnomalies")
        );
        assert_eq!(registry.lookup("/listInstances"), Some("listInstances"));
    }

    #[test]
    fn test_operation_registry_lookup_unregistered() {
        let declared = vec![op("getCostAnomalies", "read", "", "/getCostAnomalies")];
        let registry = OperationRegistry::from_entries(&declared);
        assert_eq!(registry.lookup("/unknownPath"), None);
        assert_eq!(registry.lookup(""), None);
    }

    #[test]
    fn test_operation_registry_lookup_category() {
        let declared = vec![
            op("getCostAnomalies", "read", "", "/getCostAnomalies"),
            op("deleteReservation", "delete", "", "/deleteReservation"),
            op("updateBudget", "write", "", "/updateBudget"),
        ];
        let registry = OperationRegistry::from_entries(&declared);
        assert_eq!(registry.lookup_category("/getCostAnomalies"), Some("read"));
        assert_eq!(
            registry.lookup_category("/deleteReservation"),
            Some("delete")
        );
        assert_eq!(registry.lookup_category("/updateBudget"), Some("write"));
        assert_eq!(registry.lookup_category("/unknownPath"), None);
    }

    #[test]
    fn test_operation_registry_empty_category_excluded() {
        // An empty category means "not declared".
        let declared = vec![op("legacyOp", "", "", "/legacyOp")];
        let registry = OperationRegistry::from_entries(&declared);
        // The ID is still resolvable...
        assert_eq!(registry.lookup("/legacyOp"), Some("legacyOp"));
        // ...but no category is reported.
        assert_eq!(registry.lookup_category("/legacyOp"), None);
    }

    #[test]
    fn test_operation_registry_is_empty() {
        assert!(OperationRegistry::from_entries(&[]).is_empty());

        let declared = vec![op("op1", "read", "", "/op1")];
        assert!(!OperationRegistry::from_entries(&declared).is_empty());
    }

    // ---- TOML deserialization ----------------------------------------------

    #[test]
    fn test_operation_entry_deserialization() {
        let toml_str = r#"
id = "getCostAnomalies"
category = "read"
description = "Get cost anomalies"
path = "/getCostAnomalies"
"#;
        let parsed: OperationEntry =
            toml::from_str(toml_str).expect("Failed to deserialize OperationEntry");
        assert_eq!(parsed.id, "getCostAnomalies");
        assert_eq!(parsed.category, "read");
        assert_eq!(parsed.description, "Get cost anomalies");
        assert_eq!(parsed.path, Some("/getCostAnomalies".to_string()));
    }

    #[test]
    fn test_code_mode_config_with_operations() {
        let toml_str = r#"
enabled = true

[[operations]]
id = "getCostAnomalies"
category = "read"
description = "Get cost anomalies"
path = "/getCostAnomalies"

[[operations]]
id = "listInstances"
category = "read"
path = "/listInstances"
"#;
        let parsed: CodeModeConfig = toml::from_str(toml_str).expect("Failed to deserialize");
        assert!(parsed.enabled);
        assert_eq!(parsed.operations.len(), 2);
        assert_eq!(parsed.operations[0].id, "getCostAnomalies");
        assert_eq!(parsed.operations[1].id, "listInstances");
    }

    #[test]
    fn test_code_mode_config_without_operations_defaults_to_empty() {
        let toml_str = r#"
enabled = true
"#;
        let parsed: CodeModeConfig = toml::from_str(toml_str).expect("Failed to deserialize");
        assert!(parsed.enabled);
        assert!(parsed.operations.is_empty());
    }

    #[test]
    fn test_from_toml_extracts_code_mode_section() {
        // `[server]` and `[[tools]]` surround `[code_mode]` and must be ignored.
        let toml_str = r#"
[server]
name = "cost-coach"
type = "openapi-api"

[code_mode]
enabled = true
token_ttl_seconds = 600
server_id = "cost-coach"

[[code_mode.operations]]
id = "getCostAndUsage"
category = "read"
description = "Historical cost and usage data"
path = "/getCostAndUsage"

[[code_mode.operations]]
id = "getCostAnomalies"
category = "read"
description = "Cost anomalies detected by AWS"
path = "/getCostAnomalies"

[[tools]]
name = "some_tool"
"#;
        let parsed = CodeModeConfig::from_toml(toml_str).expect("Failed to parse");
        assert!(parsed.enabled);
        assert_eq!(parsed.token_ttl_seconds, 600);
        assert_eq!(parsed.server_id, Some("cost-coach".to_string()));
        assert_eq!(parsed.operations.len(), 2);
        assert_eq!(parsed.operations[0].id, "getCostAndUsage");
        assert_eq!(parsed.operations[1].id, "getCostAnomalies");
        assert_eq!(
            parsed.operations[0].path,
            Some("/getCostAndUsage".to_string())
        );
    }

    #[test]
    fn test_from_toml_missing_code_mode_returns_default() {
        let toml_str = r#"
[server]
name = "some-server"
"#;
        let parsed = CodeModeConfig::from_toml(toml_str).expect("Failed to parse");
        assert!(!parsed.enabled);
        assert!(parsed.operations.is_empty());
        assert_eq!(parsed.token_ttl_seconds, 300); // default
    }
}