Skip to main content

heliosdb_proxy/rewriter/
config.rs

//! Query rewriter configuration.
//!
//! Configuration types for the query-rewriting system.
4
5use super::rules::RewriteRule;
6use std::time::Duration;
7
8/// Query rewriter configuration
9#[derive(Debug, Clone)]
10pub struct RewriterConfig {
11    /// Enable query rewriting
12    pub enabled: bool,
13
14    /// Log rewrite operations
15    pub log_rewrites: bool,
16
17    /// Log rewrite errors
18    pub log_errors: bool,
19
20    /// Rewrite rules
21    pub rules: Vec<RewriteRule>,
22
23    /// Automatically expand SELECT * to column list
24    pub expand_select_star: bool,
25
26    /// Add default LIMIT to queries without one
27    pub add_default_limit: bool,
28
29    /// Default LIMIT value
30    pub default_limit: u32,
31
32    /// Maximum query length to process
33    pub max_query_length: usize,
34
35    /// Cache rewritten queries by fingerprint
36    pub cache_enabled: bool,
37
38    /// Cache TTL
39    pub cache_ttl: Duration,
40
41    /// Maximum cache entries
42    pub max_cache_entries: usize,
43
44    /// Agent query safety rules
45    pub agent_safety: AgentSafetyConfig,
46}
47
48impl Default for RewriterConfig {
49    fn default() -> Self {
50        Self {
51            enabled: false,
52            log_rewrites: false,
53            log_errors: true,
54            rules: Vec::new(),
55            expand_select_star: false,
56            add_default_limit: false,
57            default_limit: 1000,
58            max_query_length: 1_000_000,
59            cache_enabled: true,
60            cache_ttl: Duration::from_secs(300),
61            max_cache_entries: 10000,
62            agent_safety: AgentSafetyConfig::default(),
63        }
64    }
65}
66
67impl RewriterConfig {
68    /// Create a new enabled config
69    pub fn enabled() -> Self {
70        Self {
71            enabled: true,
72            ..Default::default()
73        }
74    }
75
76    /// Create a builder
77    pub fn builder() -> RewriterConfigBuilder {
78        RewriterConfigBuilder::new()
79    }
80}
81
82/// Builder for RewriterConfig
83#[derive(Default)]
84pub struct RewriterConfigBuilder {
85    config: RewriterConfig,
86}
87
88impl RewriterConfigBuilder {
89    /// Create a new builder
90    pub fn new() -> Self {
91        Self {
92            config: RewriterConfig {
93                enabled: true,
94                ..Default::default()
95            },
96        }
97    }
98
99    /// Enable/disable rewriting
100    pub fn enabled(mut self, enabled: bool) -> Self {
101        self.config.enabled = enabled;
102        self
103    }
104
105    /// Log rewrites
106    pub fn log_rewrites(mut self, log: bool) -> Self {
107        self.config.log_rewrites = log;
108        self
109    }
110
111    /// Log errors
112    pub fn log_errors(mut self, log: bool) -> Self {
113        self.config.log_errors = log;
114        self
115    }
116
117    /// Add a rule
118    pub fn rule(mut self, rule: RewriteRule) -> Self {
119        self.config.rules.push(rule);
120        self
121    }
122
123    /// Add multiple rules
124    pub fn rules(mut self, rules: Vec<RewriteRule>) -> Self {
125        self.config.rules.extend(rules);
126        self
127    }
128
129    /// Enable SELECT * expansion
130    pub fn expand_select_star(mut self, enabled: bool) -> Self {
131        self.config.expand_select_star = enabled;
132        self
133    }
134
135    /// Enable default LIMIT
136    pub fn add_default_limit(mut self, enabled: bool) -> Self {
137        self.config.add_default_limit = enabled;
138        self
139    }
140
141    /// Set default LIMIT value
142    pub fn default_limit(mut self, limit: u32) -> Self {
143        self.config.default_limit = limit;
144        self
145    }
146
147    /// Set max query length
148    pub fn max_query_length(mut self, length: usize) -> Self {
149        self.config.max_query_length = length;
150        self
151    }
152
153    /// Enable caching
154    pub fn cache_enabled(mut self, enabled: bool) -> Self {
155        self.config.cache_enabled = enabled;
156        self
157    }
158
159    /// Set cache TTL
160    pub fn cache_ttl(mut self, ttl: Duration) -> Self {
161        self.config.cache_ttl = ttl;
162        self
163    }
164
165    /// Set agent safety config
166    pub fn agent_safety(mut self, config: AgentSafetyConfig) -> Self {
167        self.config.agent_safety = config;
168        self
169    }
170
171    /// Build the config
172    pub fn build(self) -> RewriterConfig {
173        self.config
174    }
175}
176
/// Safety limits applied to queries issued by AI agents.
///
/// Presets: [`AgentSafetyConfig::default`], [`AgentSafetyConfig::permissive`]
/// (trusted agents) and [`AgentSafetyConfig::restrictive`] (untrusted agents).
#[derive(Debug, Clone)]
pub struct AgentSafetyConfig {
    /// Enable agent safety rules.
    pub enabled: bool,

    /// Maximum rows for agent queries.
    pub max_rows: u32,

    /// Maximum query timeout for agents.
    pub max_timeout: Duration,

    /// Tables agents may not access.
    ///
    /// An entry ending in `*` is a prefix pattern (`"pg_catalog.*"` matches
    /// `"pg_catalog.pg_tables"`); any other entry names a single table. See
    /// [`AgentSafetyConfig::is_forbidden`] for the exact matching rules.
    pub forbidden_tables: Vec<String>,

    /// Tables whose queries must carry a WHERE clause.
    ///
    /// NOTE(review): the restrictive preset stores `"*"` here, presumably
    /// meaning "every table"; the interpretation lives in the consumer.
    pub require_where_tables: Vec<String>,

    /// Block DDL for agents.
    pub block_ddl: bool,

    /// Block admin commands for agents.
    pub block_admin: bool,
}

impl Default for AgentSafetyConfig {
    /// Safety on: 10 000 rows, 30 s timeout, system catalogs plus `secrets` and
    /// `credentials` forbidden, DDL and admin commands blocked.
    fn default() -> Self {
        Self {
            enabled: true,
            max_rows: 10_000,
            max_timeout: Duration::from_secs(30),
            forbidden_tables: Self::base_forbidden_tables(),
            require_where_tables: Vec::new(),
            block_ddl: true,
            block_admin: true,
        }
    }
}

impl AgentSafetyConfig {
    /// Forbidden-table entries shared by the default and restrictive presets.
    const BASE_FORBIDDEN_TABLES: &'static [&'static str] = &[
        "pg_catalog.*",
        "information_schema.*",
        "system.*",
        "secrets",
        "credentials",
    ];

    /// Owned copy of [`Self::BASE_FORBIDDEN_TABLES`], in declaration order.
    fn base_forbidden_tables() -> Vec<String> {
        Self::BASE_FORBIDDEN_TABLES
            .iter()
            .map(|&t| t.to_owned())
            .collect()
    }

    /// Create a permissive config (for trusted agents).
    pub fn permissive() -> Self {
        Self {
            enabled: true,
            max_rows: 100_000,
            max_timeout: Duration::from_secs(300),
            forbidden_tables: Vec::new(),
            require_where_tables: Vec::new(),
            block_ddl: false,
            block_admin: false,
        }
    }

    /// Create a restrictive config (for untrusted agents).
    ///
    /// This preset forbids the base tables plus `users` and `accounts`.
    pub fn restrictive() -> Self {
        let mut forbidden_tables = Self::base_forbidden_tables();
        forbidden_tables.extend(["users", "accounts"].iter().map(|&t| t.to_owned()));

        Self {
            enabled: true,
            max_rows: 1_000,
            max_timeout: Duration::from_secs(10),
            forbidden_tables,
            require_where_tables: vec!["*".to_string()],
            block_ddl: true,
            block_admin: true,
        }
    }

    /// Returns `true` if `table` matches any entry in
    /// [`forbidden_tables`](Self::forbidden_tables).
    ///
    /// All comparisons are ASCII case-insensitive. Unquoted SQL identifiers are
    /// case-insensitive, so a case-sensitive deny-list could be bypassed with
    /// `SECRETS` or `PG_CATALOG.pg_class`.
    ///
    /// * `prefix*`: `table` starts with `prefix` (`pg_catalog.*` matches
    ///   `pg_catalog.pg_tables`; a lone `*` matches everything).
    /// * `schema.name` (contains a dot): `table` equals the pattern.
    /// * `name` (no dot): the unqualified part of `table` (the text after its
    ///   last `.`) equals the pattern, so `secrets` also forbids `public.secrets`.
    ///
    /// NOTE(review): assumes `table` is passed without identifier quoting
    /// (`"public"."secrets"` is not unquoted here). Confirm at the call site.
    pub fn is_forbidden(&self, table: &str) -> bool {
        // "public.secrets" -> "secrets"; a name without a dot is returned whole.
        let unqualified = table.rsplit('.').next().unwrap_or(table);

        self.forbidden_tables.iter().any(|pattern| {
            if let Some(prefix) = pattern.strip_suffix('*') {
                Self::starts_with_ignore_ascii_case(table, prefix)
            } else if pattern.contains('.') {
                pattern.eq_ignore_ascii_case(table)
            } else {
                pattern.eq_ignore_ascii_case(unqualified)
            }
        })
    }

    /// ASCII case-insensitive counterpart of `str::starts_with`.
    fn starts_with_ignore_ascii_case(s: &str, prefix: &str) -> bool {
        // Length is checked first, so the byte slice below cannot go out of bounds.
        s.len() >= prefix.len()
            && s.as_bytes()[..prefix.len()].eq_ignore_ascii_case(prefix.as_bytes())
    }
}
272
/// Rewrite-rule templates that ship with the rewriter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BuiltinRule {
    /// Add index hints.
    AddIndexHints,

    /// Expand `SELECT *`.
    ExpandSelectStar,

    /// Add a default `LIMIT`.
    AddDefaultLimit,

    /// Add a tenant filter.
    AddTenantFilter,

    /// Route to a specific branch.
    RouteToBranch,

    /// Apply agent safety limits.
    AgentSafety,
}

impl BuiltinRule {
    /// Stable identifier of the rule, prefixed with `builtin:`.
    pub fn id(&self) -> &'static str {
        let (id, _) = self.metadata();
        id
    }

    /// One-line human-readable description of the rule.
    pub fn description(&self) -> &'static str {
        let (_, description) = self.metadata();
        description
    }

    /// `(id, description)` pair for this rule, kept in a single table so the
    /// two never drift apart.
    fn metadata(&self) -> (&'static str, &'static str) {
        match self {
            Self::AddIndexHints => (
                "builtin:add_index_hints",
                "Add index hints based on query patterns",
            ),
            Self::ExpandSelectStar => (
                "builtin:expand_select_star",
                "Expand SELECT * to column list",
            ),
            Self::AddDefaultLimit => (
                "builtin:add_default_limit",
                "Add LIMIT to queries without one",
            ),
            Self::AddTenantFilter => (
                "builtin:add_tenant_filter",
                "Add tenant ID filter for multi-tenancy",
            ),
            Self::RouteToBranch => (
                "builtin:route_to_branch",
                "Add branch routing hints",
            ),
            Self::AgentSafety => (
                "builtin:agent_safety",
                "Apply safety limits for AI agent queries",
            ),
        }
    }
}
320
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let cfg = RewriterConfig::default();

        assert!(!cfg.enabled, "rewriting must be off by default");
        assert!(cfg.rules.is_empty(), "no rules should be configured by default");
    }

    #[test]
    fn test_config_builder() {
        let built = RewriterConfig::builder()
            .enabled(true)
            .log_rewrites(true)
            .add_default_limit(true)
            .default_limit(500)
            .build();

        assert!(built.enabled);
        assert!(built.log_rewrites);
        assert!(built.add_default_limit);
        assert_eq!(built.default_limit, 500);
    }

    #[test]
    fn test_agent_safety_forbidden_tables() {
        let safety = AgentSafetyConfig::default();

        for blocked in ["pg_catalog.pg_tables", "secrets"] {
            assert!(safety.is_forbidden(blocked), "{blocked} should be forbidden");
        }
        assert!(!safety.is_forbidden("users"), "users is allowed by the default preset");
    }

    #[test]
    fn test_restrictive_agent_config() {
        let strict = AgentSafetyConfig::restrictive();

        assert!(strict.is_forbidden("users"));
        assert!(strict.block_ddl);
        assert_eq!(strict.max_rows, 1000);
    }
}