1use super::rules::RewriteRule;
6use std::time::Duration;
7
8#[derive(Debug, Clone)]
10pub struct RewriterConfig {
11 pub enabled: bool,
13
14 pub log_rewrites: bool,
16
17 pub log_errors: bool,
19
20 pub rules: Vec<RewriteRule>,
22
23 pub expand_select_star: bool,
25
26 pub add_default_limit: bool,
28
29 pub default_limit: u32,
31
32 pub max_query_length: usize,
34
35 pub cache_enabled: bool,
37
38 pub cache_ttl: Duration,
40
41 pub max_cache_entries: usize,
43
44 pub agent_safety: AgentSafetyConfig,
46}
47
48impl Default for RewriterConfig {
49 fn default() -> Self {
50 Self {
51 enabled: false,
52 log_rewrites: false,
53 log_errors: true,
54 rules: Vec::new(),
55 expand_select_star: false,
56 add_default_limit: false,
57 default_limit: 1000,
58 max_query_length: 1_000_000,
59 cache_enabled: true,
60 cache_ttl: Duration::from_secs(300),
61 max_cache_entries: 10000,
62 agent_safety: AgentSafetyConfig::default(),
63 }
64 }
65}
66
67impl RewriterConfig {
68 pub fn enabled() -> Self {
70 Self {
71 enabled: true,
72 ..Default::default()
73 }
74 }
75
76 pub fn builder() -> RewriterConfigBuilder {
78 RewriterConfigBuilder::new()
79 }
80}
81
82#[derive(Default)]
84pub struct RewriterConfigBuilder {
85 config: RewriterConfig,
86}
87
88impl RewriterConfigBuilder {
89 pub fn new() -> Self {
91 Self {
92 config: RewriterConfig {
93 enabled: true,
94 ..Default::default()
95 },
96 }
97 }
98
99 pub fn enabled(mut self, enabled: bool) -> Self {
101 self.config.enabled = enabled;
102 self
103 }
104
105 pub fn log_rewrites(mut self, log: bool) -> Self {
107 self.config.log_rewrites = log;
108 self
109 }
110
111 pub fn log_errors(mut self, log: bool) -> Self {
113 self.config.log_errors = log;
114 self
115 }
116
117 pub fn rule(mut self, rule: RewriteRule) -> Self {
119 self.config.rules.push(rule);
120 self
121 }
122
123 pub fn rules(mut self, rules: Vec<RewriteRule>) -> Self {
125 self.config.rules.extend(rules);
126 self
127 }
128
129 pub fn expand_select_star(mut self, enabled: bool) -> Self {
131 self.config.expand_select_star = enabled;
132 self
133 }
134
135 pub fn add_default_limit(mut self, enabled: bool) -> Self {
137 self.config.add_default_limit = enabled;
138 self
139 }
140
141 pub fn default_limit(mut self, limit: u32) -> Self {
143 self.config.default_limit = limit;
144 self
145 }
146
147 pub fn max_query_length(mut self, length: usize) -> Self {
149 self.config.max_query_length = length;
150 self
151 }
152
153 pub fn cache_enabled(mut self, enabled: bool) -> Self {
155 self.config.cache_enabled = enabled;
156 self
157 }
158
159 pub fn cache_ttl(mut self, ttl: Duration) -> Self {
161 self.config.cache_ttl = ttl;
162 self
163 }
164
165 pub fn agent_safety(mut self, config: AgentSafetyConfig) -> Self {
167 self.config.agent_safety = config;
168 self
169 }
170
171 pub fn build(self) -> RewriterConfig {
173 self.config
174 }
175}
176
/// Safety limits applied to queries issued on behalf of AI agents.
///
/// Three presets are available: `Default` (moderate),
/// [`AgentSafetyConfig::permissive`] and [`AgentSafetyConfig::restrictive`].
#[derive(Debug, Clone)]
pub struct AgentSafetyConfig {
    /// Whether agent safety checks are turned on (every preset sets `true`).
    pub enabled: bool,

    /// Row cap for agent queries (stored here; not interpreted by this type).
    pub max_rows: u32,

    /// Execution-time cap for agent queries (stored here; not interpreted by this type).
    pub max_timeout: Duration,

    /// Deny-list consulted by [`AgentSafetyConfig::is_forbidden`].
    ///
    /// An entry ending in `*` is a prefix pattern. Matching ignores ASCII case.
    pub forbidden_tables: Vec<String>,

    /// Table patterns that should require a `WHERE` clause; `restrictive()` uses
    /// `["*"]`. Not interpreted by this type.
    pub require_where_tables: Vec<String>,

    /// Whether DDL statements should be blocked for agents (not interpreted here).
    pub block_ddl: bool,

    /// Whether administrative statements should be blocked for agents (not interpreted here).
    pub block_admin: bool,
}

impl Default for AgentSafetyConfig {
    /// Moderate preset:
    /// - 10 000 rows and a 30 s timeout;
    /// - the baseline deny-list (system catalogs plus `secrets` and `credentials`);
    /// - DDL and admin statements blocked.
    fn default() -> Self {
        Self {
            enabled: true,
            max_rows: 10_000,
            max_timeout: Duration::from_secs(30),
            forbidden_tables: Self::baseline_forbidden_tables(),
            require_where_tables: Vec::new(),
            block_ddl: true,
            block_admin: true,
        }
    }
}

impl AgentSafetyConfig {
    /// Deny-list shared by the default and restrictive presets.
    const BASELINE_FORBIDDEN_TABLES: &'static [&'static str] = &[
        "pg_catalog.*",
        "information_schema.*",
        "system.*",
        "secrets",
        "credentials",
    ];

    /// Owned copy of [`Self::BASELINE_FORBIDDEN_TABLES`], preserving order.
    fn baseline_forbidden_tables() -> Vec<String> {
        Self::BASELINE_FORBIDDEN_TABLES
            .iter()
            .map(|pattern| (*pattern).to_string())
            .collect()
    }

    /// Loose preset:
    /// - 100 000 rows and a 300 s timeout;
    /// - no forbidden tables;
    /// - DDL and admin statements allowed.
    pub fn permissive() -> Self {
        Self {
            enabled: true,
            max_rows: 100_000,
            max_timeout: Duration::from_secs(300),
            forbidden_tables: Vec::new(),
            require_where_tables: Vec::new(),
            block_ddl: false,
            block_admin: false,
        }
    }

    /// Strict preset:
    /// - 1 000 rows and a 10 s timeout;
    /// - the baseline deny-list plus `users` and `accounts`;
    /// - `require_where_tables` set to `["*"]`;
    /// - DDL and admin statements blocked.
    pub fn restrictive() -> Self {
        let mut forbidden_tables = Self::baseline_forbidden_tables();
        forbidden_tables.extend(["users", "accounts"].iter().map(|t| (*t).to_string()));

        Self {
            enabled: true,
            max_rows: 1_000,
            max_timeout: Duration::from_secs(10),
            forbidden_tables,
            require_where_tables: vec!["*".to_string()],
            block_ddl: true,
            block_admin: true,
        }
    }

    /// Returns `true` if `table` matches any entry of `forbidden_tables`.
    ///
    /// - An entry ending in `*` is a prefix pattern, so `"pg_catalog.*"` matches
    ///   `"pg_catalog.pg_tables"`. Any other entry must match the whole name.
    /// - Comparison ignores ASCII case, mirroring SQL's case-insensitive unquoted
    ///   identifiers. `"SECRETS"` therefore cannot sidestep a `"secrets"` entry.
    /// - An entry without a `.` is also checked against the last dotted segment of
    ///   `table`, so a schema-qualified `"public.secrets"` is still caught.
    pub fn is_forbidden(&self, table: &str) -> bool {
        /// Prefix (`*`-suffixed) or exact match of `pattern` against `candidate`, ignoring ASCII case.
        fn pattern_matches(pattern: &str, candidate: &str) -> bool {
            match pattern.strip_suffix('*') {
                // Compare raw bytes so a multi-byte char straddling the prefix
                // length can never cause a slicing panic.
                Some(prefix) => {
                    candidate.len() >= prefix.len()
                        && candidate.as_bytes()[..prefix.len()]
                            .eq_ignore_ascii_case(prefix.as_bytes())
                }
                None => candidate.eq_ignore_ascii_case(pattern),
            }
        }

        // Last segment of a qualified name ("public.secrets" -> "secrets").
        let unqualified = table.rsplit_once('.').map_or(table, |(_, name)| name);

        self.forbidden_tables.iter().any(|pattern| {
            pattern_matches(pattern, table)
                || (!pattern.contains('.') && pattern_matches(pattern, unqualified))
        })
    }
}
272
/// Rewrite rules that ship with the rewriter, identified by a stable id.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BuiltinRule {
    /// Id `builtin:add_index_hints`.
    AddIndexHints,

    /// Id `builtin:expand_select_star`.
    ExpandSelectStar,

    /// Id `builtin:add_default_limit`.
    AddDefaultLimit,

    /// Id `builtin:add_tenant_filter`.
    AddTenantFilter,

    /// Id `builtin:route_to_branch`.
    RouteToBranch,

    /// Id `builtin:agent_safety`.
    AgentSafety,
}

impl BuiltinRule {
    /// Stable, `builtin:`-prefixed identifier of this rule.
    pub fn id(&self) -> &'static str {
        let (id, _) = self.metadata();
        id
    }

    /// One-line, human-readable description of this rule.
    pub fn description(&self) -> &'static str {
        let (_, description) = self.metadata();
        description
    }

    /// `(id, description)` for each variant, kept in one table so they stay in sync.
    fn metadata(&self) -> (&'static str, &'static str) {
        match *self {
            BuiltinRule::AddIndexHints => (
                "builtin:add_index_hints",
                "Add index hints based on query patterns",
            ),
            BuiltinRule::ExpandSelectStar => (
                "builtin:expand_select_star",
                "Expand SELECT * to column list",
            ),
            BuiltinRule::AddDefaultLimit => (
                "builtin:add_default_limit",
                "Add LIMIT to queries without one",
            ),
            BuiltinRule::AddTenantFilter => (
                "builtin:add_tenant_filter",
                "Add tenant ID filter for multi-tenancy",
            ),
            BuiltinRule::RouteToBranch => (
                "builtin:route_to_branch",
                "Add branch routing hints",
            ),
            BuiltinRule::AgentSafety => (
                "builtin:agent_safety",
                "Apply safety limits for AI agent queries",
            ),
        }
    }
}
320
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let cfg = RewriterConfig::default();

        assert!(!cfg.enabled, "rewriting must be opt-in");
        assert!(cfg.rules.is_empty(), "no rules are configured by default");
    }

    #[test]
    fn test_config_builder() {
        let built = RewriterConfig::builder()
            .enabled(true)
            .log_rewrites(true)
            .add_default_limit(true)
            .default_limit(500)
            .build();

        assert!(built.enabled);
        assert!(built.log_rewrites);
        assert!(built.add_default_limit);
        assert_eq!(built.default_limit, 500);
    }

    #[test]
    fn test_agent_safety_forbidden_tables() {
        let safety = AgentSafetyConfig::default();

        for table in &["pg_catalog.pg_tables", "secrets"] {
            assert!(safety.is_forbidden(table), "{} should be forbidden", table);
        }
        assert!(!safety.is_forbidden("users"), "users is allowed by default");
    }

    #[test]
    fn test_restrictive_agent_config() {
        let strict = AgentSafetyConfig::restrictive();

        assert_eq!(strict.max_rows, 1000);
        assert!(strict.block_ddl);
        assert!(strict.is_forbidden("users"));
    }
}