hdbconnect_mcp/security/
schema_filter.rs1use std::collections::HashSet;
4
5use crate::Error;
6
/// Policy deciding which database schemas a caller may access.
///
/// `SchemaFilter::is_allowed` upper-cases the queried schema name before
/// looking it up, so the names held by [`SchemaFilter::Whitelist`] and
/// [`SchemaFilter::Blacklist`] are expected to be upper-case;
/// `SchemaFilter::from_config` stores them that way.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum SchemaFilter {
    /// No filtering: every schema is allowed. This is the default.
    #[default]
    AllowAll,
    /// Only the listed schemas are allowed.
    Whitelist(HashSet<String>),
    /// Every schema except the listed ones is allowed.
    Blacklist(HashSet<String>),
}
18
19impl SchemaFilter {
20 pub fn is_allowed(&self, schema: &str) -> bool {
22 let schema_upper = schema.to_uppercase();
23 match self {
24 Self::AllowAll => true,
25 Self::Whitelist(allowed) => allowed.contains(&schema_upper),
26 Self::Blacklist(denied) => !denied.contains(&schema_upper),
27 }
28 }
29
30 pub fn validate(&self, schema: &str) -> Result<(), Error> {
32 if self.is_allowed(schema) {
33 Ok(())
34 } else {
35 Err(Error::SchemaAccessDenied(schema.to_string()))
36 }
37 }
38
39 pub fn from_config(mode: &str, schemas: &[String]) -> Result<Self, Error> {
41 let schemas_set: HashSet<String> = schemas.iter().map(|s| s.to_uppercase()).collect();
42
43 match mode.to_lowercase().as_str() {
44 "whitelist" | "allow" => {
45 if schemas_set.is_empty() {
46 return Err(Error::Config(
47 "Whitelist mode requires at least one schema".into(),
48 ));
49 }
50 Ok(Self::Whitelist(schemas_set))
51 }
52 "blacklist" | "deny" => Ok(Self::Blacklist(schemas_set)),
53 "none" | "all" | "" => Ok(Self::AllowAll),
54 _ => Err(Error::Config(format!(
55 "Invalid schema filter mode: {mode}. Use 'whitelist', 'blacklist', or 'none'"
56 ))),
57 }
58 }
59}
60
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects string literals into the owned set type held by the filter variants.
    fn schema_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    #[test]
    fn test_allow_all() {
        let filter = SchemaFilter::AllowAll;
        for schema in ["ANY_SCHEMA", "SYS", "system"] {
            assert!(filter.is_allowed(schema));
        }
    }

    #[test]
    fn test_whitelist() {
        let filter = SchemaFilter::Whitelist(schema_set(&["ALLOWED_SCHEMA", "APP"]));

        // Listed names pass regardless of the caller's casing.
        for schema in ["ALLOWED_SCHEMA", "allowed_schema", "APP"] {
            assert!(filter.is_allowed(schema));
        }
        // Anything not listed is rejected.
        for schema in ["OTHER", "SYS"] {
            assert!(!filter.is_allowed(schema));
        }
    }

    #[test]
    fn test_blacklist() {
        let filter = SchemaFilter::Blacklist(schema_set(&["SYS", "SYSTEM"]));

        // Listed names are rejected regardless of the caller's casing.
        for schema in ["SYS", "sys", "SYSTEM"] {
            assert!(!filter.is_allowed(schema));
        }
        // Anything not listed passes.
        for schema in ["APP", "MY_SCHEMA"] {
            assert!(filter.is_allowed(schema));
        }
    }

    #[test]
    fn test_from_config_whitelist() {
        let schemas = [String::from("SCHEMA1"), String::from("SCHEMA2")];
        let filter = SchemaFilter::from_config("whitelist", &schemas).unwrap();

        assert!(filter.is_allowed("SCHEMA1"));
        assert!(filter.is_allowed("schema2"));
        assert!(!filter.is_allowed("OTHER"));
    }

    #[test]
    fn test_from_config_blacklist() {
        let schemas = [String::from("SYS")];
        let filter = SchemaFilter::from_config("blacklist", &schemas).unwrap();

        assert!(!filter.is_allowed("SYS"));
        assert!(filter.is_allowed("APP"));
    }

    #[test]
    fn test_from_config_none() {
        let filter = SchemaFilter::from_config("none", &[]).unwrap();
        assert!(filter.is_allowed("ANY"));
    }

    #[test]
    fn test_from_config_whitelist_requires_schemas() {
        assert!(SchemaFilter::from_config("whitelist", &[]).is_err());
    }

    #[test]
    fn test_from_config_invalid_mode() {
        assert!(SchemaFilter::from_config("invalid", &[]).is_err());
    }

    #[test]
    fn test_validate() {
        let filter = SchemaFilter::Blacklist(schema_set(&["SYS"]));

        assert!(filter.validate("APP").is_ok());
        assert!(filter.validate("SYS").is_err());
    }
}