1use crate::types::{
10 CodeType, Complexity, SecurityAnalysis, SecurityIssue, SecurityIssueType, ValidationError,
11};
12use sqlparser::ast::{
13 AssignmentTarget, Expr, FromTable, GroupByExpr, Join, LimitClause, ObjectName, Query, Select,
14 SelectItem, SetExpr, Statement, TableFactor, TableObject, TableWithJoins,
15};
16use sqlparser::dialect::{Dialect, GenericDialect};
17use sqlparser::parser::Parser;
18use std::collections::HashSet;
19
/// Coarse category assigned to a parsed SQL statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlStatementType {
    /// A query (`SELECT`-style) statement.
    Select,
    /// An `INSERT` statement.
    Insert,
    /// An `UPDATE` statement.
    Update,
    /// A row-removing statement.
    Delete,
    /// A schema or privilege statement.
    Ddl,
    /// Anything that fits none of the categories above.
    Other,
}

impl SqlStatementType {
    /// Returns the upper-case label of this category.
    pub fn as_str(&self) -> &'static str {
        match *self {
            SqlStatementType::Select => "SELECT",
            SqlStatementType::Insert => "INSERT",
            SqlStatementType::Update => "UPDATE",
            SqlStatementType::Delete => "DELETE",
            SqlStatementType::Ddl => "DDL",
            SqlStatementType::Other => "OTHER",
        }
    }

    /// `true` only for [`SqlStatementType::Select`].
    pub fn is_read_only(&self) -> bool {
        *self == Self::Select
    }

    /// `true` for [`SqlStatementType::Insert`] and [`SqlStatementType::Update`].
    ///
    /// Deletes are reported separately through [`SqlStatementType::is_delete`].
    pub fn is_write(&self) -> bool {
        match self {
            Self::Insert | Self::Update => true,
            Self::Select | Self::Delete | Self::Ddl | Self::Other => false,
        }
    }

    /// `true` only for [`SqlStatementType::Delete`].
    pub fn is_delete(&self) -> bool {
        *self == Self::Delete
    }

    /// `true` only for [`SqlStatementType::Ddl`].
    pub fn is_admin(&self) -> bool {
        *self == Self::Ddl
    }
}
71
72#[derive(Debug, Clone)]
74pub struct SqlStatementInfo {
75 pub statement_type: SqlStatementType,
77
78 pub verb: String,
81
82 pub tables: HashSet<String>,
84
85 pub columns: HashSet<String>,
87
88 pub has_where: bool,
90
91 pub has_limit: bool,
93
94 pub has_order_by: bool,
96
97 pub has_aggregation: bool,
99
100 pub join_count: u32,
102
103 pub subquery_count: u32,
105
106 pub estimated_rows: u64,
108
109 pub sql_length: usize,
111}
112
113#[derive(Debug, Clone)]
115pub struct SqlValidator {
116 dialect: DialectBox,
117 default_row_estimate: u64,
118}
119
120impl Default for SqlValidator {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl SqlValidator {
127 pub fn new() -> Self {
129 Self {
130 dialect: DialectBox::Generic,
131 default_row_estimate: 1000,
132 }
133 }
134
135 pub fn validate(&self, sql: &str) -> Result<SqlStatementInfo, ValidationError> {
140 let trimmed = sql.trim();
141 if trimmed.is_empty() {
142 return Err(ValidationError::ParseError {
143 message: "SQL statement is empty".to_string(),
144 line: 1,
145 column: 1,
146 });
147 }
148
149 let statements = Parser::parse_sql(self.dialect.as_dialect(), trimmed).map_err(|e| {
150 ValidationError::ParseError {
151 message: format!("SQL parse error: {}", e),
152 line: 1,
153 column: 1,
154 }
155 })?;
156
157 match statements.len() {
158 0 => Err(ValidationError::ParseError {
159 message: "SQL contains no statements".to_string(),
160 line: 1,
161 column: 1,
162 }),
163 1 => Ok(self.analyze_statement(&statements[0], trimmed)),
164 n => Err(ValidationError::ParseError {
165 message: format!("SQL Code Mode validates one statement at a time; got {}", n),
166 line: 1,
167 column: 1,
168 }),
169 }
170 }
171
172 pub fn analyze_security(&self, info: &SqlStatementInfo) -> SecurityAnalysis {
178 let mut issues: Vec<SecurityIssue> = Vec::new();
179
180 if (info.statement_type.is_write() || info.statement_type.is_delete()) && !info.has_where {
182 issues.push(SecurityIssue::new(
183 SecurityIssueType::UnboundedQuery,
184 format!(
185 "{} statement has no WHERE clause — affects all rows in the table",
186 info.verb
187 ),
188 ));
189 }
190
191 if info.statement_type.is_read_only() && !info.has_limit {
193 issues.push(SecurityIssue::new(
194 SecurityIssueType::UnboundedQuery,
195 format!(
196 "{} statement has no LIMIT — result set may be large",
197 info.verb
198 ),
199 ));
200 }
201
202 if info.join_count > 5 {
204 issues.push(SecurityIssue::new(
205 SecurityIssueType::HighComplexity,
206 format!(
207 "Query has {} JOINs, which may be expensive to execute",
208 info.join_count
209 ),
210 ));
211 }
212 if info.subquery_count > 3 {
213 issues.push(SecurityIssue::new(
214 SecurityIssueType::DeepNesting,
215 format!("Query has {} nested subqueries", info.subquery_count),
216 ));
217 }
218
219 let complexity = estimate_complexity(info);
220
221 SecurityAnalysis {
222 is_read_only: info.statement_type.is_read_only(),
223 tables_accessed: info.tables.clone(),
224 fields_accessed: info.columns.clone(),
225 has_aggregation: info.has_aggregation,
226 has_subqueries: info.subquery_count > 0,
227 estimated_complexity: complexity,
228 potential_issues: issues,
229 estimated_rows: Some(info.estimated_rows),
230 }
231 }
232
233 pub fn to_code_type(&self, info: &SqlStatementInfo) -> CodeType {
235 if info.statement_type.is_read_only() {
236 CodeType::SqlQuery
237 } else {
238 CodeType::SqlMutation
239 }
240 }
241
242 fn analyze_statement(&self, stmt: &Statement, sql: &str) -> SqlStatementInfo {
243 let mut info = SqlStatementInfo {
244 statement_type: SqlStatementType::Other,
245 verb: verb_for(stmt),
246 tables: HashSet::new(),
247 columns: HashSet::new(),
248 has_where: false,
249 has_limit: false,
250 has_order_by: false,
251 has_aggregation: false,
252 join_count: 0,
253 subquery_count: 0,
254 estimated_rows: self.default_row_estimate,
255 sql_length: sql.len(),
256 };
257
258 match stmt {
259 Statement::Query(query) => {
260 info.statement_type = SqlStatementType::Select;
261 self.analyze_query(query, &mut info);
262 },
263 Statement::Insert(insert) => {
264 info.statement_type = SqlStatementType::Insert;
265 if let TableObject::TableName(name) = &insert.table {
266 add_object_name(&mut info.tables, name);
267 }
268 for col in &insert.columns {
269 info.columns.insert(col.to_string());
272 }
273 if let Some(source) = &insert.source {
274 self.analyze_query(source, &mut info);
275 }
276 },
277 Statement::Update(update) => {
278 info.statement_type = SqlStatementType::Update;
279 self.analyze_table_with_joins(&update.table, &mut info);
280 for assignment in &update.assignments {
281 match &assignment.target {
282 AssignmentTarget::ColumnName(name) => {
283 add_object_name(&mut info.columns, name);
284 },
285 AssignmentTarget::Tuple(names) => {
286 for n in names {
287 add_object_name(&mut info.columns, n);
288 }
289 },
290 }
291 self.analyze_expr(&assignment.value, &mut info);
292 }
293 if let Some(expr) = &update.selection {
294 info.has_where = true;
295 self.analyze_expr(expr, &mut info);
296 }
297 },
298 Statement::Delete(delete) => {
299 info.statement_type = SqlStatementType::Delete;
300 match &delete.from {
301 FromTable::WithFromKeyword(tables) | FromTable::WithoutKeyword(tables) => {
302 for t in tables {
303 self.analyze_table_with_joins(t, &mut info);
304 }
305 },
306 }
307 for t in &delete.tables {
309 add_object_name(&mut info.tables, t);
310 }
311 if let Some(expr) = &delete.selection {
312 info.has_where = true;
313 self.analyze_expr(expr, &mut info);
314 }
315 },
316 Statement::Truncate(truncate) => {
317 info.statement_type = SqlStatementType::Delete;
318 for tn in &truncate.table_names {
319 add_object_name(&mut info.tables, &tn.name);
320 }
321 },
322 Statement::CreateTable(create) => {
323 info.statement_type = SqlStatementType::Ddl;
324 add_object_name(&mut info.tables, &create.name);
325 },
326 Statement::AlterTable(alter) => {
327 info.statement_type = SqlStatementType::Ddl;
328 add_object_name(&mut info.tables, &alter.name);
329 },
330 Statement::Drop { .. }
331 | Statement::CreateIndex(_)
332 | Statement::CreateView { .. }
333 | Statement::Grant { .. }
334 | Statement::Revoke { .. } => {
335 info.statement_type = SqlStatementType::Ddl;
336 },
337 _ => {
338 },
340 }
341
342 info
343 }
344
345 fn analyze_query(&self, query: &Query, info: &mut SqlStatementInfo) {
346 if query.order_by.is_some() {
347 info.has_order_by = true;
348 }
349 if let Some(limit_clause) = &query.limit_clause {
350 info.has_limit = true;
351 let limit_expr = match limit_clause {
352 LimitClause::LimitOffset { limit, .. } => limit.as_ref(),
353 LimitClause::OffsetCommaLimit { limit, .. } => Some(limit),
354 };
355 if let Some(Expr::Value(v)) = limit_expr {
356 if let sqlparser::ast::Value::Number(n, _) = &v.value {
357 if let Ok(parsed) = n.parse::<u64>() {
358 info.estimated_rows = parsed;
359 }
360 }
361 }
362 }
363
364 self.analyze_set_expr(&query.body, info);
365 }
366
367 fn analyze_set_expr(&self, set_expr: &SetExpr, info: &mut SqlStatementInfo) {
368 match set_expr {
369 SetExpr::Select(select) => self.analyze_select(select, info),
370 SetExpr::Query(inner) => {
371 info.subquery_count += 1;
372 self.analyze_query(inner, info);
373 },
374 SetExpr::SetOperation { left, right, .. } => {
375 self.analyze_set_expr(left, info);
376 self.analyze_set_expr(right, info);
377 },
378 _ => {},
379 }
380 }
381
382 fn analyze_select(&self, select: &Select, info: &mut SqlStatementInfo) {
383 for item in &select.projection {
385 match item {
386 SelectItem::UnnamedExpr(expr) => self.analyze_expr(expr, info),
387 SelectItem::ExprWithAlias { expr, .. } => self.analyze_expr(expr, info),
388 SelectItem::ExprWithAliases { expr, .. } => self.analyze_expr(expr, info),
391 SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {
392 info.columns.insert("*".to_string());
393 },
394 }
395 }
396
397 for table in &select.from {
399 self.analyze_table_with_joins(table, info);
400 }
401
402 if let Some(expr) = &select.selection {
404 info.has_where = true;
405 self.analyze_expr(expr, info);
406 }
407
408 if !group_by_is_empty(&select.group_by) {
410 info.has_aggregation = true;
411 }
412 }
413
414 fn analyze_table_with_joins(&self, item: &TableWithJoins, info: &mut SqlStatementInfo) {
415 self.analyze_table_factor(&item.relation, info);
416 for join in &item.joins {
417 info.join_count += 1;
418 self.analyze_join(join, info);
419 }
420 }
421
422 fn analyze_join(&self, join: &Join, info: &mut SqlStatementInfo) {
423 self.analyze_table_factor(&join.relation, info);
424 }
425
426 fn analyze_table_factor(&self, factor: &TableFactor, info: &mut SqlStatementInfo) {
427 match factor {
428 TableFactor::Table { name, .. } => add_object_name(&mut info.tables, name),
429 TableFactor::Derived { subquery, .. } => {
430 info.subquery_count += 1;
431 self.analyze_query(subquery, info);
432 },
433 TableFactor::NestedJoin {
434 table_with_joins, ..
435 } => self.analyze_table_with_joins(table_with_joins, info),
436 _ => {},
437 }
438 }
439
440 fn analyze_expr(&self, expr: &Expr, info: &mut SqlStatementInfo) {
441 match expr {
442 Expr::Identifier(id) => {
443 info.columns.insert(id.value.clone());
444 },
445 Expr::CompoundIdentifier(ids) => {
446 if let Some(last) = ids.last() {
447 info.columns.insert(last.value.clone());
448 }
449 },
450 Expr::Subquery(q)
451 | Expr::Exists { subquery: q, .. }
452 | Expr::InSubquery { subquery: q, .. } => {
453 info.subquery_count += 1;
454 self.analyze_query(q, info);
455 },
456 Expr::Function(f) => {
457 let name = f.name.to_string().to_uppercase();
458 if matches!(
459 name.as_str(),
460 "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "ARRAY_AGG" | "STRING_AGG"
461 ) {
462 info.has_aggregation = true;
463 }
464 },
465 _ => {},
466 }
467 }
468}
469
470fn estimate_complexity(info: &SqlStatementInfo) -> Complexity {
471 let joins = info.join_count;
472 let subs = info.subquery_count;
473 if joins >= 5 || subs >= 3 {
474 Complexity::High
475 } else if joins >= 2 || subs >= 1 || info.has_aggregation {
476 Complexity::Medium
477 } else {
478 Complexity::Low
479 }
480}
481
482fn group_by_is_empty(group_by: &GroupByExpr) -> bool {
483 match group_by {
484 GroupByExpr::All(_) => true,
485 GroupByExpr::Expressions(exprs, _) => exprs.is_empty(),
486 }
487}
488
489fn add_object_name(out: &mut HashSet<String>, name: &ObjectName) {
490 if let Some(last) = name.0.last() {
491 out.insert(last.to_string());
492 } else {
493 out.insert(name.to_string());
494 }
495}
496
497fn verb_for(stmt: &Statement) -> String {
498 match stmt {
499 Statement::Query(_) => "SELECT".to_string(),
500 Statement::Insert(_) => "INSERT".to_string(),
501 Statement::Update { .. } => "UPDATE".to_string(),
502 Statement::Delete(_) => "DELETE".to_string(),
503 Statement::Truncate { .. } => "TRUNCATE".to_string(),
504 Statement::CreateTable(_) => "CREATE TABLE".to_string(),
505 Statement::AlterTable { .. } => "ALTER TABLE".to_string(),
506 Statement::Drop { .. } => "DROP".to_string(),
507 Statement::CreateIndex(_) => "CREATE INDEX".to_string(),
508 Statement::CreateView { .. } => "CREATE VIEW".to_string(),
509 Statement::Grant { .. } => "GRANT".to_string(),
510 Statement::Revoke { .. } => "REVOKE".to_string(),
511 other => format!("{:?}", other)
512 .split('(')
513 .next()
514 .unwrap_or("OTHER")
515 .to_uppercase(),
516 }
517}
518
519#[derive(Debug, Clone)]
522enum DialectBox {
523 Generic,
524}
525
526impl DialectBox {
527 fn as_dialect(&self) -> &dyn Dialect {
528 match self {
529 Self::Generic => &GenericDialect {},
530 }
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Validates `sql` with a fresh validator, panicking if it is rejected.
    fn analyze(sql: &str) -> SqlStatementInfo {
        SqlValidator::new().validate(sql).unwrap()
    }

    /// Whether a fresh validator rejects `sql` with a parse error.
    fn is_parse_error(sql: &str) -> bool {
        matches!(
            SqlValidator::new().validate(sql),
            Err(ValidationError::ParseError { .. })
        )
    }

    #[test]
    fn select_simple() {
        let info = analyze("SELECT id, name FROM users");
        assert_eq!(info.statement_type, SqlStatementType::Select);
        assert!(info.tables.contains("users"));
        for column in ["id", "name"] {
            assert!(info.columns.contains(column));
        }
        assert!(!info.has_where);
        assert!(!info.has_limit);
    }

    #[test]
    fn select_with_where_limit_order() {
        let info = analyze("SELECT id FROM users WHERE active = 1 ORDER BY id LIMIT 10");
        assert!(info.has_where);
        assert!(info.has_limit);
        assert!(info.has_order_by);
        assert_eq!(info.estimated_rows, 10);
    }

    #[test]
    fn select_star() {
        let info = analyze("SELECT * FROM users");
        assert!(info.columns.contains("*"));
    }

    #[test]
    fn select_join_and_subquery() {
        let info = analyze(
            "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id \
             WHERE u.id IN (SELECT id FROM admins)",
        );
        assert_eq!(info.join_count, 1);
        assert!(info.subquery_count >= 1);
        for table in ["users", "orders", "admins"] {
            assert!(info.tables.contains(table));
        }
    }

    #[test]
    fn insert_extracts_table_and_columns() {
        let info = analyze("INSERT INTO users (id, name) VALUES (1, 'Alice')");
        assert_eq!(info.statement_type, SqlStatementType::Insert);
        assert!(info.tables.contains("users"));
        for column in ["id", "name"] {
            assert!(info.columns.contains(column));
        }
    }

    #[test]
    fn update_without_where_flagged() {
        let validator = SqlValidator::new();
        let info = validator.validate("UPDATE users SET active = 0").unwrap();
        assert_eq!(info.statement_type, SqlStatementType::Update);
        assert!(!info.has_where);

        let analysis = validator.analyze_security(&info);
        let flagged = analysis
            .potential_issues
            .iter()
            .any(|issue| issue.issue_type == SecurityIssueType::UnboundedQuery);
        assert!(flagged);
    }

    #[test]
    fn update_with_where() {
        let info = analyze("UPDATE users SET active = 0 WHERE id = 1");
        assert_eq!(info.statement_type, SqlStatementType::Update);
        assert!(info.has_where);
        assert!(info.columns.contains("active"));
    }

    #[test]
    fn delete_with_where() {
        let info = analyze("DELETE FROM users WHERE id = 1");
        assert_eq!(info.statement_type, SqlStatementType::Delete);
        assert!(info.has_where);
    }

    #[test]
    fn ddl_is_admin() {
        let info = analyze("CREATE TABLE foo (id INT)");
        assert_eq!(info.statement_type, SqlStatementType::Ddl);
        assert!(info.statement_type.is_admin());
    }

    #[test]
    fn empty_sql_rejected() {
        assert!(is_parse_error(""));
        assert!(is_parse_error(" "));
    }

    #[test]
    fn syntax_error_rejected() {
        assert!(is_parse_error("SELEC id FRM users"));
    }

    #[test]
    fn multiple_statements_rejected() {
        assert!(is_parse_error("SELECT 1; SELECT 2"));
    }

    #[test]
    fn aggregation_detected() {
        let info = analyze("SELECT COUNT(*) FROM users");
        assert!(info.has_aggregation);
    }

    #[test]
    fn group_by_detected() {
        let info = analyze("SELECT role, COUNT(*) FROM users GROUP BY role");
        assert!(info.has_aggregation);
    }
}