1use regex::{Regex, RegexBuilder};
19use tracing::warn;
20
21use chio_guards::{extract_action, ToolAction};
22use chio_kernel::{GuardContext, KernelError, Verdict};
23
24use crate::config::{SqlGuardConfig, SqlOperation};
25use crate::error::SqlGuardDenyReason;
26use crate::sql_parser::{self, SqlAnalysis};
27
28pub struct SqlQueryGuard {
30 config: SqlGuardConfig,
31 denylist_regex: Vec<(String, Regex)>,
32}
33
/// Maximum number of `denylisted_predicates` entries accepted at construction.
const MAX_DENYLISTED_PREDICATES: usize = 64;
/// Maximum length of one trimmed denylist pattern, measured with `str::len` (bytes).
const MAX_DENYLISTED_PREDICATE_LEN: usize = 512;
/// Maximum heuristic score (see `predicate_pattern_complexity`) for one pattern.
const MAX_DENYLISTED_PREDICATE_COMPLEXITY: usize = 96;
/// Compiled-program size limit (1 MiB) passed to `RegexBuilder::size_limit`.
const DENYLISTED_PREDICATE_REGEX_SIZE_LIMIT: usize = 1024 * 1024;
/// Lazy-DFA cache size limit (1 MiB) passed to `RegexBuilder::dfa_size_limit`.
const DENYLISTED_PREDICATE_DFA_SIZE_LIMIT: usize = 1024 * 1024;
39
40impl SqlQueryGuard {
41 pub fn new(config: SqlGuardConfig) -> Self {
47 match Self::try_new(config) {
48 Ok(guard) => guard,
49 Err(error) => {
50 warn!(
51 target: "chio.data-guards.sql",
52 error = %error,
53 "invalid sql-query-guard config; constructing fail-closed deny-all guard"
54 );
55 Self {
56 config: SqlGuardConfig::default(),
57 denylist_regex: Vec::new(),
58 }
59 }
60 }
61 }
62
63 pub fn try_new(config: SqlGuardConfig) -> Result<Self, String> {
65 if config.allow_all {
66 warn!(
67 target: "chio.data-guards.sql",
68 "sql-query-guard constructed with allow_all=true; fail-closed default disabled"
69 );
70 }
71
72 if config.denylisted_predicates.len() > MAX_DENYLISTED_PREDICATES {
73 return Err(format!(
74 "sql_query.denylisted_predicates allows at most {MAX_DENYLISTED_PREDICATES} patterns"
75 ));
76 }
77 let mut denylist_regex = Vec::with_capacity(config.denylisted_predicates.len());
78 for pattern in &config.denylisted_predicates {
79 let trimmed = pattern.trim();
80 if trimmed.is_empty() {
81 return Err("sql_query.denylisted_predicates cannot contain empty patterns".into());
82 }
83 if trimmed.len() > MAX_DENYLISTED_PREDICATE_LEN {
84 return Err(format!(
85 "sql_query.denylisted_predicates entries must be at most {MAX_DENYLISTED_PREDICATE_LEN} characters"
86 ));
87 }
88 let complexity = predicate_pattern_complexity(trimmed);
89 if complexity > MAX_DENYLISTED_PREDICATE_COMPLEXITY {
90 return Err(format!(
91 "sql_query.denylisted_predicates entries must have complexity at most {MAX_DENYLISTED_PREDICATE_COMPLEXITY}"
92 ));
93 }
94 let re = RegexBuilder::new(trimmed)
95 .case_insensitive(true)
96 .size_limit(DENYLISTED_PREDICATE_REGEX_SIZE_LIMIT)
97 .dfa_size_limit(DENYLISTED_PREDICATE_DFA_SIZE_LIMIT)
98 .build()
99 .map_err(|error| {
100 format!("invalid sql_query.denylisted_predicates entry `{trimmed}`: {error}")
101 })?;
102 denylist_regex.push((trimmed.to_string(), re));
103 }
104
105 Ok(Self {
106 config,
107 denylist_regex,
108 })
109 }
110
111 pub fn config(&self) -> &SqlGuardConfig {
114 &self.config
115 }
116
117 pub fn analyze(&self, query: &str) -> Result<SqlAnalysis, SqlGuardDenyReason> {
124 let analysis = sql_parser::parse(query, self.config.dialect)
126 .map_err(|e| SqlGuardDenyReason::ParseError { error: e })?;
127
128 if self.config.allow_all {
129 return Ok(analysis);
130 }
131
132 if self.config.is_empty() {
133 return Err(SqlGuardDenyReason::NoConfig);
134 }
135
136 self.enforce_operation(&analysis)?;
137 self.enforce_tables(&analysis)?;
138 self.enforce_columns(&analysis)?;
139 self.enforce_predicate_denylist(&analysis)?;
140 self.enforce_where_for_mutations(&analysis)?;
141
142 Ok(analysis)
143 }
144
145 fn enforce_operation(&self, analysis: &SqlAnalysis) -> Result<(), SqlGuardDenyReason> {
146 if self.config.operation_allowlist.is_empty() {
147 return Err(SqlGuardDenyReason::OperationNotAllowed {
150 operation: analysis.operation.as_str().to_string(),
151 });
152 }
153 if !self
154 .config
155 .operation_allowlist
156 .contains(&analysis.operation)
157 {
158 return Err(SqlGuardDenyReason::OperationNotAllowed {
159 operation: analysis.operation.as_str().to_string(),
160 });
161 }
162 Ok(())
163 }
164
165 fn enforce_tables(&self, analysis: &SqlAnalysis) -> Result<(), SqlGuardDenyReason> {
166 if self.config.table_allowlist.is_empty() {
167 return Err(SqlGuardDenyReason::TableNotAllowed {
168 table: analysis
169 .tables
170 .first()
171 .cloned()
172 .unwrap_or_else(|| "<none>".to_string()),
173 });
174 }
175 for table in &analysis.tables {
176 if !self.config.table_allowed(table) {
177 return Err(SqlGuardDenyReason::TableNotAllowed {
178 table: table.clone(),
179 });
180 }
181 }
182 Ok(())
183 }
184
185 fn enforce_columns(&self, analysis: &SqlAnalysis) -> Result<(), SqlGuardDenyReason> {
186 if analysis.operation != SqlOperation::Select {
187 return Ok(());
188 }
189 let Some(_) = self.config.column_allowlist.as_ref() else {
190 return Ok(());
191 };
192
193 for (table, column) in &analysis.projected_columns {
194 if column == "*" {
198 if self.config.table_has_column_allowlist(table) {
199 return Err(SqlGuardDenyReason::SelectStarDenied {
200 table: table.clone(),
201 });
202 }
203 continue;
204 }
205
206 if column == "?" {
221 return Err(SqlGuardDenyReason::ColumnNotAllowed {
222 table: table.clone(),
223 column: "?".to_string(),
224 });
225 }
226
227 match self.config.column_allowed(table, column) {
229 Some(true) => {}
230 Some(false) => {
231 return Err(SqlGuardDenyReason::ColumnNotAllowed {
232 table: table.clone(),
233 column: column.clone(),
234 })
235 }
236 None => {
237 }
239 }
240 }
241 Ok(())
242 }
243
244 fn enforce_predicate_denylist(&self, analysis: &SqlAnalysis) -> Result<(), SqlGuardDenyReason> {
245 if self.denylist_regex.is_empty() {
246 return Ok(());
247 }
248 if analysis.where_canonical.is_empty() {
249 return Ok(());
250 }
251 for (pattern, re) in &self.denylist_regex {
252 if re.is_match(&analysis.where_canonical) {
253 return Err(SqlGuardDenyReason::PredicateDenylisted {
254 pattern: pattern.clone(),
255 });
256 }
257 }
258 Ok(())
259 }
260
261 fn enforce_where_for_mutations(
262 &self,
263 analysis: &SqlAnalysis,
264 ) -> Result<(), SqlGuardDenyReason> {
265 if !self.config.require_where_for_mutations {
266 return Ok(());
267 }
268 let needs_where = matches!(
269 analysis.operation,
270 SqlOperation::Update | SqlOperation::Delete
271 );
272 if needs_where && !analysis.has_where {
273 return Err(SqlGuardDenyReason::MissingWhereClause {
274 operation: analysis.operation.as_str().to_string(),
275 });
276 }
277 Ok(())
278 }
279}
280
/// Cheap heuristic score for how expensive a denylist regex might be.
///
/// Each unescaped `|`, `*`, `+`, or `?` adds 4. Each unescaped `{`, `[`, or `(`
/// adds 2. A backslash and the character it escapes add nothing. A trailing lone
/// backslash is ignored. The sum saturates instead of overflowing.
fn predicate_pattern_complexity(pattern: &str) -> usize {
    let mut chars = pattern.chars();
    let mut total = 0usize;
    while let Some(c) = chars.next() {
        let weight = match c {
            '\\' => {
                // Consume the escaped character so it is never scored.
                let _ = chars.next();
                0
            }
            '|' | '*' | '+' | '?' => 4,
            '{' | '[' | '(' => 2,
            _ => 0,
        };
        total = total.saturating_add(weight);
    }
    total
}
298
299impl chio_kernel::Guard for SqlQueryGuard {
300 fn name(&self) -> &str {
301 "sql-query"
302 }
303
304 fn evaluate(&self, ctx: &GuardContext) -> Result<Verdict, KernelError> {
305 let action = extract_action(&ctx.request.tool_name, &ctx.request.arguments);
306 let (database, query) = match &action {
307 ToolAction::DatabaseQuery { database, query } => (database.as_str(), query.as_str()),
308 _ => return Ok(Verdict::Allow),
309 };
310
311 match self.analyze(query) {
312 Ok(_) => Ok(Verdict::Allow),
313 Err(reason) => {
314 warn!(
315 target: "chio.data-guards.sql",
316 database = %database,
317 code = reason.code(),
318 reason = %reason,
319 "sql-query-guard denied query"
320 );
321 Ok(Verdict::Deny)
322 }
323 }
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use crate::config::{SqlDialect, SqlGuardConfig, SqlOperation};

    /// Baseline policy: `SELECT` only, on the `orders` table only.
    fn cfg_select_orders() -> SqlGuardConfig {
        SqlGuardConfig {
            dialect: SqlDialect::Generic,
            operation_allowlist: vec![SqlOperation::Select],
            table_allowlist: vec!["orders".to_string()],
            ..Default::default()
        }
    }

    /// Baseline policy plus the given predicate denylist.
    fn cfg_with_denylist(patterns: Vec<String>) -> SqlGuardConfig {
        SqlGuardConfig {
            denylisted_predicates: patterns,
            ..cfg_select_orders()
        }
    }

    #[test]
    fn allow_select_from_allowed_table() {
        let g = SqlQueryGuard::new(cfg_select_orders());
        g.analyze("SELECT id FROM orders").expect("allowed");
    }

    #[test]
    fn deny_select_from_unlisted_table() {
        let g = SqlQueryGuard::new(cfg_select_orders());
        let err = g.analyze("SELECT * FROM users").expect_err("denied");
        assert!(matches!(err, SqlGuardDenyReason::TableNotAllowed { .. }));
    }

    #[test]
    fn deny_drop_when_ddl_not_allowed() {
        let g = SqlQueryGuard::new(cfg_select_orders());
        let err = g.analyze("DROP TABLE orders").expect_err("denied");
        assert!(matches!(
            err,
            SqlGuardDenyReason::OperationNotAllowed { .. }
        ));
    }

    #[test]
    fn deny_update_when_only_select_allowed() {
        let g = SqlQueryGuard::new(cfg_select_orders());
        let err = g
            .analyze("UPDATE orders SET foo=1 WHERE id=1")
            .expect_err("denied");
        assert!(matches!(
            err,
            SqlGuardDenyReason::OperationNotAllowed { .. }
        ));
    }

    #[test]
    fn deny_malformed_sql() {
        let g = SqlQueryGuard::new(cfg_select_orders());
        let err = g.analyze("SELEKT oops").expect_err("denied");
        assert!(matches!(err, SqlGuardDenyReason::ParseError { .. }));
    }

    #[test]
    fn empty_config_denies() {
        let g = SqlQueryGuard::new(SqlGuardConfig::default());
        let err = g.analyze("SELECT 1").expect_err("denied");
        assert!(matches!(err, SqlGuardDenyReason::NoConfig));
    }

    #[test]
    fn allow_all_still_denies_parse_errors() {
        let g = SqlQueryGuard::new(SqlGuardConfig {
            allow_all: true,
            ..Default::default()
        });
        let err = g.analyze("NOT SQL AT ALL ;;;;").expect_err("denied");
        assert!(matches!(err, SqlGuardDenyReason::ParseError { .. }));
    }

    #[test]
    fn allow_all_permits_well_formed_query() {
        let g = SqlQueryGuard::new(SqlGuardConfig {
            allow_all: true,
            ..Default::default()
        });
        g.analyze("SELECT id FROM whatever").expect("allowed");
    }

    #[test]
    fn column_allowlist_denies_unlisted_column() {
        let mut map = HashMap::new();
        map.insert(
            "orders".to_string(),
            vec!["id".to_string(), "total".to_string()],
        );
        let cfg = SqlGuardConfig {
            operation_allowlist: vec![SqlOperation::Select],
            table_allowlist: vec!["orders".into()],
            column_allowlist: Some(map),
            ..Default::default()
        };
        let g = SqlQueryGuard::new(cfg);
        g.analyze("SELECT id, total FROM orders").expect("allowed");
        let err = g
            .analyze("SELECT id, email FROM orders")
            .expect_err("denied");
        assert!(matches!(err, SqlGuardDenyReason::ColumnNotAllowed { .. }));
    }

    #[test]
    fn select_star_denied_when_column_allowlist_active() {
        let mut map = HashMap::new();
        map.insert("orders".to_string(), vec!["id".to_string()]);
        let cfg = SqlGuardConfig {
            operation_allowlist: vec![SqlOperation::Select],
            table_allowlist: vec!["orders".into()],
            column_allowlist: Some(map),
            ..Default::default()
        };
        let g = SqlQueryGuard::new(cfg);
        let err = g.analyze("SELECT * FROM orders").expect_err("denied");
        assert!(matches!(err, SqlGuardDenyReason::SelectStarDenied { .. }));
    }

    #[test]
    fn predicate_denylist_blocks_or_1_equals_1() {
        let cfg = SqlGuardConfig {
            operation_allowlist: vec![SqlOperation::Select],
            table_allowlist: vec!["orders".into()],
            denylisted_predicates: vec![r"\bor\s+1\s*=\s*1\b".to_string()],
            ..Default::default()
        };
        let g = SqlQueryGuard::new(cfg);
        let err = g
            .analyze("SELECT id FROM orders WHERE id = 1 OR 1=1")
            .expect_err("denied");
        assert!(matches!(
            err,
            SqlGuardDenyReason::PredicateDenylisted { .. }
        ));
    }

    #[test]
    fn predicate_denylist_ignores_queries_without_where() {
        let g = SqlQueryGuard::new(cfg_with_denylist(vec![
            r"\bor\s+1\s*=\s*1\b".to_string()
        ]));
        g.analyze("SELECT id FROM orders").expect("allowed");
    }

    #[test]
    fn mutation_without_where_is_denied() {
        let cfg = SqlGuardConfig {
            operation_allowlist: vec![SqlOperation::Delete],
            table_allowlist: vec!["orders".into()],
            ..Default::default()
        };
        let g = SqlQueryGuard::new(cfg);
        let err = g.analyze("DELETE FROM orders").expect_err("denied");
        assert!(matches!(err, SqlGuardDenyReason::MissingWhereClause { .. }));
    }

    #[test]
    fn mutation_where_optional_when_disabled() {
        let cfg = SqlGuardConfig {
            operation_allowlist: vec![SqlOperation::Delete],
            table_allowlist: vec!["orders".into()],
            require_where_for_mutations: false,
            ..Default::default()
        };
        let g = SqlQueryGuard::new(cfg);
        g.analyze("DELETE FROM orders").expect("allowed");
    }

    #[test]
    fn try_new_rejects_too_many_predicates() {
        let patterns = vec!["a".to_string(); MAX_DENYLISTED_PREDICATES + 1];
        assert!(SqlQueryGuard::try_new(cfg_with_denylist(patterns)).is_err());
    }

    #[test]
    fn try_new_rejects_blank_predicate() {
        let cfg = cfg_with_denylist(vec!["   ".to_string()]);
        assert!(SqlQueryGuard::try_new(cfg).is_err());
    }

    #[test]
    fn try_new_rejects_overlong_predicate() {
        let pattern = "a".repeat(MAX_DENYLISTED_PREDICATE_LEN + 1);
        assert!(SqlQueryGuard::try_new(cfg_with_denylist(vec![pattern])).is_err());
    }

    #[test]
    fn try_new_rejects_overly_complex_predicate() {
        // Each unescaped `|` scores 4, so this exceeds the complexity ceiling.
        let alternations = MAX_DENYLISTED_PREDICATE_COMPLEXITY / 4 + 1;
        let pattern = format!("{}a", "a|".repeat(alternations));
        assert!(SqlQueryGuard::try_new(cfg_with_denylist(vec![pattern])).is_err());
    }

    #[test]
    fn try_new_rejects_invalid_regex() {
        let cfg = cfg_with_denylist(vec!["(unclosed".to_string()]);
        assert!(SqlQueryGuard::try_new(cfg).is_err());
    }

    #[test]
    fn try_new_accepts_valid_denylist() {
        let cfg = cfg_with_denylist(vec![r"\bor\s+1\s*=\s*1\b".to_string()]);
        assert!(SqlQueryGuard::try_new(cfg).is_ok());
    }

    #[test]
    fn invalid_config_falls_back_to_deny_all() {
        let g = SqlQueryGuard::new(cfg_with_denylist(vec!["(unclosed".to_string()]));
        let err = g.analyze("SELECT id FROM orders").expect_err("denied");
        assert!(matches!(err, SqlGuardDenyReason::NoConfig));
    }

    #[test]
    fn invalid_config_overrides_allow_all() {
        let g = SqlQueryGuard::new(SqlGuardConfig {
            allow_all: true,
            denylisted_predicates: vec!["(unclosed".to_string()],
            ..Default::default()
        });
        let err = g.analyze("SELECT 1").expect_err("denied");
        assert!(matches!(err, SqlGuardDenyReason::NoConfig));
    }

    #[test]
    fn complexity_ignores_escaped_metacharacters() {
        assert_eq!(predicate_pattern_complexity(r"\|\*\(\["), 0);
        assert_eq!(predicate_pattern_complexity(r"\bor\s+1\s*=\s*1\b"), 12);
        assert_eq!(predicate_pattern_complexity("(a|b)[c]{2}"), 10);
    }
}