Skip to main content

hdbconnect_mcp/security/
schema_filter.rs

1//! Schema access filtering
2
3use std::collections::HashSet;
4
5use crate::Error;
6
7/// Schema access filter configuration
8#[derive(Debug, Clone, Default)]
9pub enum SchemaFilter {
10    /// Allow access to all schemas (default, backward compatible)
11    #[default]
12    AllowAll,
13    /// Only allow access to specified schemas
14    Whitelist(HashSet<String>),
15    /// Deny access to specified schemas, allow all others
16    Blacklist(HashSet<String>),
17}
18
19impl SchemaFilter {
20    /// Check if access to a schema is allowed
21    pub fn is_allowed(&self, schema: &str) -> bool {
22        let schema_upper = schema.to_uppercase();
23        match self {
24            Self::AllowAll => true,
25            Self::Whitelist(allowed) => allowed.contains(&schema_upper),
26            Self::Blacklist(denied) => !denied.contains(&schema_upper),
27        }
28    }
29
30    /// Validate schema access, returning an error if denied
31    pub fn validate(&self, schema: &str) -> Result<(), Error> {
32        if self.is_allowed(schema) {
33            Ok(())
34        } else {
35            Err(Error::SchemaAccessDenied(schema.to_string()))
36        }
37    }
38
39    /// Create a filter from configuration strings
40    pub fn from_config(mode: &str, schemas: &[String]) -> Result<Self, Error> {
41        let schemas_set: HashSet<String> = schemas.iter().map(|s| s.to_uppercase()).collect();
42
43        match mode.to_lowercase().as_str() {
44            "whitelist" | "allow" => {
45                if schemas_set.is_empty() {
46                    return Err(Error::Config(
47                        "Whitelist mode requires at least one schema".into(),
48                    ));
49                }
50                Ok(Self::Whitelist(schemas_set))
51            }
52            "blacklist" | "deny" => Ok(Self::Blacklist(schemas_set)),
53            "none" | "all" | "" => Ok(Self::AllowAll),
54            _ => Err(Error::Config(format!(
55                "Invalid schema filter mode: {mode}. Use 'whitelist', 'blacklist', or 'none'"
56            ))),
57        }
58    }
59}
60
#[cfg(test)]
mod tests {
    use super::*;

    /// Owned schema set from string literals, for constructing filters directly.
    fn set_of(names: &[&str]) -> HashSet<String> {
        names.iter().map(|s| (*s).to_string()).collect()
    }

    /// Owned schema list in the shape `SchemaFilter::from_config` accepts.
    fn list_of(names: &[&str]) -> Vec<String> {
        names.iter().map(|s| (*s).to_string()).collect()
    }

    #[test]
    fn test_default_is_allow_all() {
        assert!(matches!(SchemaFilter::default(), SchemaFilter::AllowAll));
    }

    #[test]
    fn test_allow_all() {
        let filter = SchemaFilter::AllowAll;
        assert!(filter.is_allowed("ANY_SCHEMA"));
        assert!(filter.is_allowed("SYS"));
        assert!(filter.is_allowed("system"));
    }

    #[test]
    fn test_whitelist() {
        let filter = SchemaFilter::Whitelist(set_of(&["ALLOWED_SCHEMA", "APP"]));

        assert!(filter.is_allowed("ALLOWED_SCHEMA"));
        assert!(filter.is_allowed("allowed_schema")); // case insensitive
        assert!(filter.is_allowed("APP"));
        assert!(!filter.is_allowed("OTHER"));
        assert!(!filter.is_allowed("SYS"));
    }

    #[test]
    fn test_blacklist() {
        let filter = SchemaFilter::Blacklist(set_of(&["SYS", "SYSTEM"]));

        assert!(!filter.is_allowed("SYS"));
        assert!(!filter.is_allowed("sys")); // case insensitive
        assert!(!filter.is_allowed("SYSTEM"));
        assert!(filter.is_allowed("APP"));
        assert!(filter.is_allowed("MY_SCHEMA"));
    }

    #[test]
    fn test_from_config_whitelist() {
        let filter = SchemaFilter::from_config("whitelist", &list_of(&["SCHEMA1", "SCHEMA2"]))
            .unwrap();

        assert!(filter.is_allowed("SCHEMA1"));
        assert!(filter.is_allowed("schema2"));
        assert!(!filter.is_allowed("OTHER"));
    }

    #[test]
    fn test_from_config_lowercase_entries_are_normalized() {
        let filter = SchemaFilter::from_config("whitelist", &list_of(&["app"])).unwrap();

        assert!(filter.is_allowed("APP"));
        assert!(filter.is_allowed("app"));
        assert!(!filter.is_allowed("SYS"));
    }

    #[test]
    fn test_from_config_blacklist() {
        let filter = SchemaFilter::from_config("blacklist", &list_of(&["SYS"])).unwrap();

        assert!(!filter.is_allowed("SYS"));
        assert!(filter.is_allowed("APP"));
    }

    #[test]
    fn test_from_config_blacklist_empty_allows_everything() {
        let filter = SchemaFilter::from_config("blacklist", &[]).unwrap();

        assert!(filter.is_allowed("SYS"));
        assert!(filter.is_allowed("APP"));
    }

    #[test]
    fn test_from_config_mode_aliases() {
        let schemas = list_of(&["APP"]);

        // "allow" is an alias for "whitelist".
        let allow = SchemaFilter::from_config("allow", &schemas).unwrap();
        assert!(matches!(allow, SchemaFilter::Whitelist(_)));
        assert!(allow.is_allowed("APP"));
        assert!(!allow.is_allowed("OTHER"));

        // "deny" is an alias for "blacklist".
        let deny = SchemaFilter::from_config("deny", &schemas).unwrap();
        assert!(matches!(deny, SchemaFilter::Blacklist(_)));
        assert!(!deny.is_allowed("APP"));
        assert!(deny.is_allowed("OTHER"));

        // "none", "all" and the empty string all disable filtering; schemas are ignored.
        for mode in ["none", "all", ""] {
            let filter = SchemaFilter::from_config(mode, &schemas).unwrap();
            assert!(matches!(filter, SchemaFilter::AllowAll), "mode {mode:?}");
        }
    }

    #[test]
    fn test_from_config_mode_is_case_insensitive() {
        let schemas = list_of(&["SYS"]);

        let filter = SchemaFilter::from_config("BlackList", &schemas).unwrap();
        assert!(!filter.is_allowed("SYS"));
        assert!(filter.is_allowed("APP"));

        let filter = SchemaFilter::from_config("WHITELIST", &schemas).unwrap();
        assert!(filter.is_allowed("SYS"));
        assert!(!filter.is_allowed("APP"));
    }

    #[test]
    fn test_from_config_none() {
        let filter = SchemaFilter::from_config("none", &[]).unwrap();
        assert!(filter.is_allowed("ANY"));
    }

    #[test]
    fn test_from_config_whitelist_requires_schemas() {
        let result = SchemaFilter::from_config("whitelist", &[]);
        assert!(matches!(result, Err(Error::Config(_))));
    }

    #[test]
    fn test_from_config_invalid_mode() {
        let result = SchemaFilter::from_config("invalid", &[]);
        assert!(matches!(result, Err(Error::Config(_))));
    }

    #[test]
    fn test_validate() {
        let filter = SchemaFilter::Blacklist(set_of(&["SYS"]));

        assert!(filter.validate("APP").is_ok());
        assert!(filter.validate("SYS").is_err());
    }

    #[test]
    fn test_validate_error_carries_requested_name() {
        let filter = SchemaFilter::Blacklist(set_of(&["SYS"]));

        // The error reports the schema exactly as requested, not its uppercased form.
        match filter.validate("sys") {
            Err(Error::SchemaAccessDenied(name)) => assert_eq!(name, "sys"),
            other => panic!("expected SchemaAccessDenied, got {other:?}"),
        }
    }
}