1use crate::types::RiskLevel;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::collections::HashSet;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct OperationEntry {
12 pub id: String,
15
16 pub category: String,
19
20 #[serde(default)]
22 pub description: String,
23
24 #[serde(default)]
27 pub path: Option<String>,
28}
29
30#[derive(Debug, Clone, Default)]
33pub struct OperationRegistry {
34 path_to_id: HashMap<String, String>,
35 path_to_category: HashMap<String, String>,
36}
37
38impl OperationRegistry {
39 pub fn from_entries(entries: &[OperationEntry]) -> Self {
40 let mut path_to_id = HashMap::with_capacity(entries.len());
41 let mut path_to_category = HashMap::with_capacity(entries.len());
42 for entry in entries {
43 if let Some(ref path) = entry.path {
44 path_to_id.insert(path.clone(), entry.id.clone());
45 if !entry.category.is_empty() {
46 path_to_category.insert(path.clone(), entry.category.clone());
47 }
48 }
49 }
50 Self {
51 path_to_id,
52 path_to_category,
53 }
54 }
55
56 pub fn lookup(&self, path: &str) -> Option<&str> {
57 self.path_to_id.get(path).map(|s| s.as_str())
58 }
59
60 pub fn lookup_category(&self, path: &str) -> Option<&str> {
63 self.path_to_category.get(path).map(|s| s.as_str())
64 }
65
66 pub fn is_empty(&self) -> bool {
67 self.path_to_id.is_empty()
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct CodeModeConfig {
74 #[serde(default)]
76 pub enabled: bool,
77
78 #[serde(default)]
83 pub allow_mutations: bool,
84
85 #[serde(default)]
87 pub allowed_mutations: HashSet<String>,
88
89 #[serde(default)]
91 pub blocked_mutations: HashSet<String>,
92
93 #[serde(default)]
95 pub allow_introspection: bool,
96
97 #[serde(default)]
99 pub blocked_fields: HashSet<String>,
100
101 #[serde(default)]
103 pub allowed_queries: HashSet<String>,
104
105 #[serde(default)]
107 pub blocked_queries: HashSet<String>,
108
109 #[serde(default = "default_true")]
114 pub openapi_reads_enabled: bool,
115
116 #[serde(default)]
118 pub openapi_allow_writes: bool,
119
120 #[serde(default)]
122 pub openapi_allowed_writes: HashSet<String>,
123
124 #[serde(default)]
126 pub openapi_blocked_writes: HashSet<String>,
127
128 #[serde(default)]
130 pub openapi_allow_deletes: bool,
131
132 #[serde(default)]
134 pub openapi_allowed_deletes: HashSet<String>,
135
136 #[serde(default)]
138 pub openapi_blocked_paths: HashSet<String>,
139
140 #[serde(default)]
142 pub openapi_internal_blocked_fields: HashSet<String>,
143
144 #[serde(default)]
146 pub openapi_output_blocked_fields: HashSet<String>,
147
148 #[serde(default)]
150 pub openapi_require_output_declaration: bool,
151
152 #[serde(default)]
157 pub action_tags: HashMap<String, String>,
158
159 #[serde(default = "default_max_depth")]
161 pub max_depth: u32,
162
163 #[serde(default = "default_max_field_count")]
165 pub max_field_count: u32,
166
167 #[serde(default = "default_max_cost")]
169 pub max_cost: u32,
170
171 #[serde(default)]
173 pub allowed_sensitive_categories: HashSet<String>,
174
175 #[serde(default = "default_token_ttl")]
177 pub token_ttl_seconds: i64,
178
179 #[serde(default = "default_auto_approve_levels")]
181 pub auto_approve_levels: Vec<RiskLevel>,
182
183 #[serde(default = "default_max_query_length")]
185 pub max_query_length: usize,
186
187 #[serde(default = "default_max_result_rows")]
189 pub max_result_rows: usize,
190
191 #[serde(default = "default_query_timeout")]
193 pub query_timeout_seconds: u32,
194
195 #[serde(default)]
197 pub server_id: Option<String>,
198
199 #[serde(default)]
206 pub sdk_operations: HashSet<String>,
207
208 #[serde(default)]
213 pub operations: Vec<OperationEntry>,
214}
215
216impl Default for CodeModeConfig {
217 fn default() -> Self {
218 Self {
219 enabled: false,
220 allow_mutations: false,
222 allowed_mutations: HashSet::new(),
223 blocked_mutations: HashSet::new(),
224 allow_introspection: false,
225 blocked_fields: HashSet::new(),
226 allowed_queries: HashSet::new(),
227 blocked_queries: HashSet::new(),
228 openapi_reads_enabled: true,
230 openapi_allow_writes: false,
231 openapi_allowed_writes: HashSet::new(),
232 openapi_blocked_writes: HashSet::new(),
233 openapi_allow_deletes: false,
234 openapi_allowed_deletes: HashSet::new(),
235 openapi_blocked_paths: HashSet::new(),
236 openapi_internal_blocked_fields: HashSet::new(),
237 openapi_output_blocked_fields: HashSet::new(),
238 openapi_require_output_declaration: false,
239 action_tags: HashMap::new(),
241 max_depth: default_max_depth(),
242 max_field_count: default_max_field_count(),
243 max_cost: default_max_cost(),
244 allowed_sensitive_categories: HashSet::new(),
245 token_ttl_seconds: default_token_ttl(),
246 auto_approve_levels: default_auto_approve_levels(),
247 max_query_length: default_max_query_length(),
248 max_result_rows: default_max_result_rows(),
249 query_timeout_seconds: default_query_timeout(),
250 server_id: None,
251 sdk_operations: HashSet::new(),
253 operations: Vec::new(),
254 }
255 }
256}
257
258#[derive(Deserialize)]
261struct TomlWrapper {
262 #[serde(default)]
263 code_mode: CodeModeConfig,
264}
265
266impl CodeModeConfig {
267 pub fn from_toml(toml_str: &str) -> Result<Self, toml::de::Error> {
282 let wrapper: TomlWrapper = toml::from_str(toml_str)?;
283 Ok(wrapper.code_mode)
284 }
285
286 pub fn enabled() -> Self {
288 Self {
289 enabled: true,
290 ..Default::default()
291 }
292 }
293
294 pub fn is_sdk_mode(&self) -> bool {
296 !self.sdk_operations.is_empty()
297 }
298
299 pub fn should_auto_approve(&self, risk_level: RiskLevel) -> bool {
301 self.auto_approve_levels.contains(&risk_level)
302 }
303
304 pub fn server_id(&self) -> &str {
306 self.server_id.as_deref().unwrap_or("unknown")
307 }
308
309 pub fn to_server_config_entity(&self) -> crate::policy::ServerConfigEntity {
311 crate::policy::ServerConfigEntity {
312 server_id: self.server_id().to_string(),
313 server_type: "graphql".to_string(),
314 allow_write: self.allow_mutations,
315 allow_delete: self.allow_mutations,
316 allow_admin: self.allow_introspection,
317 allowed_operations: self.allowed_mutations.clone(),
318 blocked_operations: self.blocked_mutations.clone(),
319 max_depth: self.max_depth,
320 max_field_count: self.max_field_count,
321 max_cost: self.max_cost,
322 max_api_calls: 50,
323 blocked_fields: self.blocked_fields.clone(),
324 allowed_sensitive_categories: self.allowed_sensitive_categories.clone(),
325 }
326 }
327
328 #[cfg(feature = "openapi-code-mode")]
330 pub fn to_openapi_server_entity(&self) -> crate::policy::OpenAPIServerEntity {
331 let mut allowed_operations = self.openapi_allowed_writes.clone();
332 allowed_operations.extend(self.openapi_allowed_deletes.clone());
333
334 let write_mode = if !self.openapi_allow_writes {
335 "deny_all"
336 } else if !self.openapi_allowed_writes.is_empty() {
337 "allowlist"
338 } else if !self.openapi_blocked_writes.is_empty() {
339 "blocklist"
340 } else {
341 "allow_all"
342 };
343
344 crate::policy::OpenAPIServerEntity {
345 server_id: self.server_id().to_string(),
346 server_type: "openapi".to_string(),
347 allow_write: self.openapi_allow_writes,
348 allow_delete: self.openapi_allow_deletes,
349 allow_admin: false,
350 write_mode: write_mode.to_string(),
351 max_depth: self.max_depth,
352 max_cost: self.max_cost,
353 max_api_calls: 50,
354 max_loop_iterations: 100,
355 max_script_length: self.max_query_length as u32,
356 max_nesting_depth: self.max_depth,
357 execution_timeout_seconds: self.query_timeout_seconds,
358 allowed_operations,
359 blocked_operations: self.openapi_blocked_writes.clone(),
360 allowed_methods: HashSet::new(),
361 blocked_methods: HashSet::new(),
362 allowed_path_patterns: HashSet::new(),
363 blocked_path_patterns: self.openapi_blocked_paths.clone(),
364 sensitive_path_patterns: self.openapi_blocked_paths.clone(),
365 auto_approve_read_only: self.openapi_reads_enabled,
366 max_api_calls_for_auto_approve: 10,
367 internal_blocked_fields: self.openapi_internal_blocked_fields.clone(),
368 output_blocked_fields: self.openapi_output_blocked_fields.clone(),
369 require_output_declaration: self.openapi_require_output_declaration,
370 }
371 }
372}
373
/// Serde default for boolean settings that are on unless a config turns them
/// off (used for `openapi_reads_enabled` in this file).
fn default_true() -> bool {
    const ON: bool = true;
    ON
}
377
/// Default `token_ttl_seconds`: 300 seconds, i.e. five minutes.
fn default_token_ttl() -> i64 {
    5 * 60
}
381
382fn default_auto_approve_levels() -> Vec<RiskLevel> {
383 vec![RiskLevel::Low]
384}
385
/// Default `max_query_length`: ten thousand.
fn default_max_query_length() -> usize {
    10_000
}
389
/// Default `max_result_rows`: ten thousand.
fn default_max_result_rows() -> usize {
    10_000
}
393
/// Default `query_timeout_seconds`: 30.
fn default_query_timeout() -> u32 {
    30_u32
}
397
/// Default `max_depth`: 10.
fn default_max_depth() -> u32 {
    10_u32
}
401
/// Default `max_field_count`: 100.
fn default_max_field_count() -> u32 {
    100_u32
}
405
/// Default `max_cost`: one thousand.
fn default_max_cost() -> u32 {
    1_000
}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`OperationEntry`] whose `path` is always set.
    fn entry(id: &str, category: &str, description: &str, path: &str) -> OperationEntry {
        OperationEntry {
            id: id.to_owned(),
            category: category.to_owned(),
            description: description.to_owned(),
            path: Some(path.to_owned()),
        }
    }

    #[test]
    fn test_default_config() {
        let cfg = CodeModeConfig::default();
        assert!(!cfg.enabled);
        assert!(!cfg.allow_mutations);
        assert_eq!(cfg.token_ttl_seconds, 300);
        assert_eq!(cfg.auto_approve_levels, vec![RiskLevel::Low]);
    }

    #[test]
    fn test_enabled_config() {
        assert!(CodeModeConfig::enabled().enabled);
    }

    #[test]
    fn test_auto_approve() {
        let cfg = CodeModeConfig::default();
        assert!(cfg.should_auto_approve(RiskLevel::Low));
        for level in [RiskLevel::Medium, RiskLevel::High, RiskLevel::Critical] {
            assert!(!cfg.should_auto_approve(level));
        }
    }

    #[test]
    fn test_operation_registry_from_entries() {
        let entries = [
            entry("getCostAnomalies", "read", "Get cost anomalies", "/getCostAnomalies"),
            entry("listInstances", "read", "List EC2 instances", "/listInstances"),
        ];
        let registry = OperationRegistry::from_entries(&entries);
        assert_eq!(registry.lookup("/getCostAnomalies"), Some("getCostAnomalies"));
        assert_eq!(registry.lookup("/listInstances"), Some("listInstances"));
    }

    #[test]
    fn test_operation_registry_lookup_unregistered() {
        let registry = OperationRegistry::from_entries(&[entry(
            "getCostAnomalies",
            "read",
            "",
            "/getCostAnomalies",
        )]);
        assert_eq!(registry.lookup("/unknownPath"), None);
        assert_eq!(registry.lookup(""), None);
    }

    #[test]
    fn test_operation_registry_lookup_category() {
        let registry = OperationRegistry::from_entries(&[
            entry("getCostAnomalies", "read", "", "/getCostAnomalies"),
            entry("deleteReservation", "delete", "", "/deleteReservation"),
            entry("updateBudget", "write", "", "/updateBudget"),
        ]);
        let expectations = [
            ("/getCostAnomalies", Some("read")),
            ("/deleteReservation", Some("delete")),
            ("/updateBudget", Some("write")),
            ("/unknownPath", None),
        ];
        for (path, expected) in expectations {
            assert_eq!(registry.lookup_category(path), expected);
        }
    }

    #[test]
    fn test_operation_registry_empty_category_excluded() {
        let registry = OperationRegistry::from_entries(&[entry("legacyOp", "", "", "/legacyOp")]);
        // Still resolvable by id, but carries no category.
        assert_eq!(registry.lookup("/legacyOp"), Some("legacyOp"));
        assert_eq!(registry.lookup_category("/legacyOp"), None);
    }

    #[test]
    fn test_operation_registry_is_empty() {
        assert!(OperationRegistry::from_entries(&[]).is_empty());

        let populated = OperationRegistry::from_entries(&[entry("op1", "read", "", "/op1")]);
        assert!(!populated.is_empty());
    }

    #[test]
    fn test_operation_entry_deserialization() {
        let toml_str = r#"
id = "getCostAnomalies"
category = "read"
description = "Get cost anomalies"
path = "/getCostAnomalies"
"#;
        let parsed = toml::from_str::<OperationEntry>(toml_str)
            .expect("Failed to deserialize OperationEntry");
        assert_eq!(parsed.id, "getCostAnomalies");
        assert_eq!(parsed.category, "read");
        assert_eq!(parsed.description, "Get cost anomalies");
        assert_eq!(parsed.path, Some("/getCostAnomalies".to_string()));
    }

    #[test]
    fn test_code_mode_config_with_operations() {
        let toml_str = r#"
enabled = true

[[operations]]
id = "getCostAnomalies"
category = "read"
description = "Get cost anomalies"
path = "/getCostAnomalies"

[[operations]]
id = "listInstances"
category = "read"
path = "/listInstances"
"#;
        let cfg = toml::from_str::<CodeModeConfig>(toml_str).expect("Failed to deserialize");
        assert!(cfg.enabled);
        let ids: Vec<&str> = cfg.operations.iter().map(|op| op.id.as_str()).collect();
        assert_eq!(ids.len(), 2);
        assert_eq!(ids[0], "getCostAnomalies");
        assert_eq!(ids[1], "listInstances");
    }

    #[test]
    fn test_code_mode_config_without_operations_defaults_to_empty() {
        let toml_str = r#"
enabled = true
"#;
        let cfg = toml::from_str::<CodeModeConfig>(toml_str).expect("Failed to deserialize");
        assert!(cfg.enabled);
        assert!(cfg.operations.is_empty());
    }

    #[test]
    fn test_from_toml_extracts_code_mode_section() {
        let toml_str = r#"
[server]
name = "cost-coach"
type = "openapi-api"

[code_mode]
enabled = true
token_ttl_seconds = 600
server_id = "cost-coach"

[[code_mode.operations]]
id = "getCostAndUsage"
category = "read"
description = "Historical cost and usage data"
path = "/getCostAndUsage"

[[code_mode.operations]]
id = "getCostAnomalies"
category = "read"
description = "Cost anomalies detected by AWS"
path = "/getCostAnomalies"

[[tools]]
name = "some_tool"
"#;
        let cfg = CodeModeConfig::from_toml(toml_str).expect("Failed to parse");
        assert!(cfg.enabled);
        assert_eq!(cfg.token_ttl_seconds, 600);
        assert_eq!(cfg.server_id, Some("cost-coach".to_string()));

        let ops = &cfg.operations;
        assert_eq!(ops.len(), 2);
        assert_eq!(ops[0].id, "getCostAndUsage");
        assert_eq!(ops[1].id, "getCostAnomalies");
        assert_eq!(ops[0].path, Some("/getCostAndUsage".to_string()));
    }

    #[test]
    fn test_from_toml_missing_code_mode_returns_default() {
        let toml_str = r#"
[server]
name = "some-server"
"#;
        let cfg = CodeModeConfig::from_toml(toml_str).expect("Failed to parse");
        assert!(!cfg.enabled);
        assert!(cfg.operations.is_empty());
        assert_eq!(cfg.token_ttl_seconds, 300);
    }
}