1use crate::types::{
10 CodeType, Complexity, SecurityAnalysis, SecurityIssue, SecurityIssueType, ValidationError,
11};
12use sqlparser::ast::{
13 AssignmentTarget, Expr, FromTable, GroupByExpr, Join, LimitClause, ObjectName, Query, Select,
14 SelectItem, SetExpr, Statement, TableFactor, TableObject, TableWithJoins,
15};
16use sqlparser::dialect::{Dialect, GenericDialect};
17use sqlparser::parser::Parser;
18use std::collections::HashSet;
19
/// Coarse classification of a parsed SQL statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlStatementType {
    /// A query (`SELECT`-style) statement.
    Select,
    /// An `INSERT` statement.
    Insert,
    /// An `UPDATE` statement.
    Update,
    /// A row-removing statement (`DELETE`).
    Delete,
    /// A schema or privilege statement.
    Ddl,
    /// Anything that does not fall into the categories above.
    Other,
}

impl SqlStatementType {
    /// Returns the fixed upper-case label for this statement type.
    pub fn as_str(&self) -> &'static str {
        match *self {
            SqlStatementType::Other => "OTHER",
            SqlStatementType::Ddl => "DDL",
            SqlStatementType::Delete => "DELETE",
            SqlStatementType::Update => "UPDATE",
            SqlStatementType::Insert => "INSERT",
            SqlStatementType::Select => "SELECT",
        }
    }

    /// Returns `true` only for [`SqlStatementType::Select`].
    pub fn is_read_only(&self) -> bool {
        *self == SqlStatementType::Select
    }

    /// Returns `true` for `Insert` and `Update`; `Delete` is reported by
    /// [`SqlStatementType::is_delete`] instead.
    pub fn is_write(&self) -> bool {
        [SqlStatementType::Insert, SqlStatementType::Update].contains(self)
    }

    /// Returns `true` only for [`SqlStatementType::Delete`].
    pub fn is_delete(&self) -> bool {
        *self == SqlStatementType::Delete
    }

    /// Returns `true` only for [`SqlStatementType::Ddl`].
    pub fn is_admin(&self) -> bool {
        *self == SqlStatementType::Ddl
    }
}
71
/// Structural facts collected from a single parsed SQL statement.
#[derive(Debug, Clone)]
pub struct SqlStatementInfo {
    /// Coarse classification of the statement.
    pub statement_type: SqlStatementType,

    /// Upper-case verb describing the statement (for example `SELECT` or
    /// `CREATE TABLE`).
    pub verb: String,

    /// Names of the tables the statement refers to. Only the last segment of
    /// a qualified name is stored.
    pub tables: HashSet<String>,

    /// Names of the columns the statement refers to; `*` is recorded for
    /// wildcard projections.
    pub columns: HashSet<String>,

    /// Whether a `WHERE` clause was found.
    pub has_where: bool,

    /// Whether a `LIMIT` clause was found.
    pub has_limit: bool,

    /// Whether an `ORDER BY` clause was found.
    pub has_order_by: bool,

    /// Whether grouping or a recognized aggregate function call was found.
    pub has_aggregation: bool,

    /// Number of `JOIN`s encountered.
    pub join_count: u32,

    /// Number of nested queries encountered.
    pub subquery_count: u32,

    /// Row estimate: the validator's default, replaced by a numeric literal
    /// `LIMIT` when one is present.
    pub estimated_rows: u64,

    /// Length in bytes of the analyzed SQL text.
    pub sql_length: usize,
}
112
/// Parses a single SQL statement and extracts structural information from it.
#[derive(Debug, Clone)]
pub struct SqlValidator {
    // Dialect handed to the SQL parser.
    dialect: DialectBox,
    // Initial value for `SqlStatementInfo::estimated_rows`.
    default_row_estimate: u64,
}
119
120impl Default for SqlValidator {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl SqlValidator {
127 pub fn new() -> Self {
129 Self {
130 dialect: DialectBox::Generic,
131 default_row_estimate: 1000,
132 }
133 }
134
135 pub fn validate(&self, sql: &str) -> Result<SqlStatementInfo, ValidationError> {
140 let trimmed = sql.trim();
141 if trimmed.is_empty() {
142 return Err(ValidationError::ParseError {
143 message: "SQL statement is empty".to_string(),
144 line: 1,
145 column: 1,
146 });
147 }
148
149 let statements = Parser::parse_sql(self.dialect.as_dialect(), trimmed).map_err(|e| {
150 ValidationError::ParseError {
151 message: format!("SQL parse error: {}", e),
152 line: 1,
153 column: 1,
154 }
155 })?;
156
157 match statements.len() {
158 0 => Err(ValidationError::ParseError {
159 message: "SQL contains no statements".to_string(),
160 line: 1,
161 column: 1,
162 }),
163 1 => Ok(self.analyze_statement(&statements[0], trimmed)),
164 n => Err(ValidationError::ParseError {
165 message: format!("SQL Code Mode validates one statement at a time; got {}", n),
166 line: 1,
167 column: 1,
168 }),
169 }
170 }
171
172 pub fn analyze_security(&self, info: &SqlStatementInfo) -> SecurityAnalysis {
178 let mut issues: Vec<SecurityIssue> = Vec::new();
179
180 if (info.statement_type.is_write() || info.statement_type.is_delete()) && !info.has_where {
182 issues.push(SecurityIssue::new(
183 SecurityIssueType::UnboundedQuery,
184 format!(
185 "{} statement has no WHERE clause — affects all rows in the table",
186 info.verb
187 ),
188 ));
189 }
190
191 if info.statement_type.is_read_only() && !info.has_limit {
193 issues.push(SecurityIssue::new(
194 SecurityIssueType::UnboundedQuery,
195 format!(
196 "{} statement has no LIMIT — result set may be large",
197 info.verb
198 ),
199 ));
200 }
201
202 if info.join_count > 5 {
204 issues.push(SecurityIssue::new(
205 SecurityIssueType::HighComplexity,
206 format!(
207 "Query has {} JOINs, which may be expensive to execute",
208 info.join_count
209 ),
210 ));
211 }
212 if info.subquery_count > 3 {
213 issues.push(SecurityIssue::new(
214 SecurityIssueType::DeepNesting,
215 format!("Query has {} nested subqueries", info.subquery_count),
216 ));
217 }
218
219 let complexity = estimate_complexity(info);
220
221 SecurityAnalysis {
222 is_read_only: info.statement_type.is_read_only(),
223 tables_accessed: info.tables.clone(),
224 fields_accessed: info.columns.clone(),
225 has_aggregation: info.has_aggregation,
226 has_subqueries: info.subquery_count > 0,
227 estimated_complexity: complexity,
228 potential_issues: issues,
229 estimated_rows: Some(info.estimated_rows),
230 }
231 }
232
233 pub fn to_code_type(&self, info: &SqlStatementInfo) -> CodeType {
235 if info.statement_type.is_read_only() {
236 CodeType::SqlQuery
237 } else {
238 CodeType::SqlMutation
239 }
240 }
241
242 fn analyze_statement(&self, stmt: &Statement, sql: &str) -> SqlStatementInfo {
243 let mut info = SqlStatementInfo {
244 statement_type: SqlStatementType::Other,
245 verb: verb_for(stmt),
246 tables: HashSet::new(),
247 columns: HashSet::new(),
248 has_where: false,
249 has_limit: false,
250 has_order_by: false,
251 has_aggregation: false,
252 join_count: 0,
253 subquery_count: 0,
254 estimated_rows: self.default_row_estimate,
255 sql_length: sql.len(),
256 };
257
258 match stmt {
259 Statement::Query(query) => {
260 info.statement_type = SqlStatementType::Select;
261 self.analyze_query(query, &mut info);
262 },
263 Statement::Insert(insert) => {
264 info.statement_type = SqlStatementType::Insert;
265 if let TableObject::TableName(name) = &insert.table {
266 add_object_name(&mut info.tables, name);
267 }
268 for col in &insert.columns {
269 info.columns.insert(col.value.clone());
270 }
271 if let Some(source) = &insert.source {
272 self.analyze_query(source, &mut info);
273 }
274 },
275 Statement::Update(update) => {
276 info.statement_type = SqlStatementType::Update;
277 self.analyze_table_with_joins(&update.table, &mut info);
278 for assignment in &update.assignments {
279 match &assignment.target {
280 AssignmentTarget::ColumnName(name) => {
281 add_object_name(&mut info.columns, name);
282 },
283 AssignmentTarget::Tuple(names) => {
284 for n in names {
285 add_object_name(&mut info.columns, n);
286 }
287 },
288 }
289 self.analyze_expr(&assignment.value, &mut info);
290 }
291 if let Some(expr) = &update.selection {
292 info.has_where = true;
293 self.analyze_expr(expr, &mut info);
294 }
295 },
296 Statement::Delete(delete) => {
297 info.statement_type = SqlStatementType::Delete;
298 match &delete.from {
299 FromTable::WithFromKeyword(tables) | FromTable::WithoutKeyword(tables) => {
300 for t in tables {
301 self.analyze_table_with_joins(t, &mut info);
302 }
303 },
304 }
305 for t in &delete.tables {
307 add_object_name(&mut info.tables, t);
308 }
309 if let Some(expr) = &delete.selection {
310 info.has_where = true;
311 self.analyze_expr(expr, &mut info);
312 }
313 },
314 Statement::Truncate(truncate) => {
315 info.statement_type = SqlStatementType::Delete;
316 for tn in &truncate.table_names {
317 add_object_name(&mut info.tables, &tn.name);
318 }
319 },
320 Statement::CreateTable(create) => {
321 info.statement_type = SqlStatementType::Ddl;
322 add_object_name(&mut info.tables, &create.name);
323 },
324 Statement::AlterTable(alter) => {
325 info.statement_type = SqlStatementType::Ddl;
326 add_object_name(&mut info.tables, &alter.name);
327 },
328 Statement::Drop { .. }
329 | Statement::CreateIndex(_)
330 | Statement::CreateView { .. }
331 | Statement::Grant { .. }
332 | Statement::Revoke { .. } => {
333 info.statement_type = SqlStatementType::Ddl;
334 },
335 _ => {
336 },
338 }
339
340 info
341 }
342
343 fn analyze_query(&self, query: &Query, info: &mut SqlStatementInfo) {
344 if query.order_by.is_some() {
345 info.has_order_by = true;
346 }
347 if let Some(limit_clause) = &query.limit_clause {
348 info.has_limit = true;
349 let limit_expr = match limit_clause {
350 LimitClause::LimitOffset { limit, .. } => limit.as_ref(),
351 LimitClause::OffsetCommaLimit { limit, .. } => Some(limit),
352 };
353 if let Some(Expr::Value(v)) = limit_expr {
354 if let sqlparser::ast::Value::Number(n, _) = &v.value {
355 if let Ok(parsed) = n.parse::<u64>() {
356 info.estimated_rows = parsed;
357 }
358 }
359 }
360 }
361
362 self.analyze_set_expr(&query.body, info);
363 }
364
365 fn analyze_set_expr(&self, set_expr: &SetExpr, info: &mut SqlStatementInfo) {
366 match set_expr {
367 SetExpr::Select(select) => self.analyze_select(select, info),
368 SetExpr::Query(inner) => {
369 info.subquery_count += 1;
370 self.analyze_query(inner, info);
371 },
372 SetExpr::SetOperation { left, right, .. } => {
373 self.analyze_set_expr(left, info);
374 self.analyze_set_expr(right, info);
375 },
376 _ => {},
377 }
378 }
379
380 fn analyze_select(&self, select: &Select, info: &mut SqlStatementInfo) {
381 for item in &select.projection {
383 match item {
384 SelectItem::UnnamedExpr(expr) => self.analyze_expr(expr, info),
385 SelectItem::ExprWithAlias { expr, .. } => self.analyze_expr(expr, info),
386 SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {
387 info.columns.insert("*".to_string());
388 },
389 }
390 }
391
392 for table in &select.from {
394 self.analyze_table_with_joins(table, info);
395 }
396
397 if let Some(expr) = &select.selection {
399 info.has_where = true;
400 self.analyze_expr(expr, info);
401 }
402
403 if !group_by_is_empty(&select.group_by) {
405 info.has_aggregation = true;
406 }
407 }
408
409 fn analyze_table_with_joins(&self, item: &TableWithJoins, info: &mut SqlStatementInfo) {
410 self.analyze_table_factor(&item.relation, info);
411 for join in &item.joins {
412 info.join_count += 1;
413 self.analyze_join(join, info);
414 }
415 }
416
417 fn analyze_join(&self, join: &Join, info: &mut SqlStatementInfo) {
418 self.analyze_table_factor(&join.relation, info);
419 }
420
421 fn analyze_table_factor(&self, factor: &TableFactor, info: &mut SqlStatementInfo) {
422 match factor {
423 TableFactor::Table { name, .. } => add_object_name(&mut info.tables, name),
424 TableFactor::Derived { subquery, .. } => {
425 info.subquery_count += 1;
426 self.analyze_query(subquery, info);
427 },
428 TableFactor::NestedJoin {
429 table_with_joins, ..
430 } => self.analyze_table_with_joins(table_with_joins, info),
431 _ => {},
432 }
433 }
434
435 fn analyze_expr(&self, expr: &Expr, info: &mut SqlStatementInfo) {
436 match expr {
437 Expr::Identifier(id) => {
438 info.columns.insert(id.value.clone());
439 },
440 Expr::CompoundIdentifier(ids) => {
441 if let Some(last) = ids.last() {
442 info.columns.insert(last.value.clone());
443 }
444 },
445 Expr::Subquery(q)
446 | Expr::Exists { subquery: q, .. }
447 | Expr::InSubquery { subquery: q, .. } => {
448 info.subquery_count += 1;
449 self.analyze_query(q, info);
450 },
451 Expr::Function(f) => {
452 let name = f.name.to_string().to_uppercase();
453 if matches!(
454 name.as_str(),
455 "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "ARRAY_AGG" | "STRING_AGG"
456 ) {
457 info.has_aggregation = true;
458 }
459 },
460 _ => {},
461 }
462 }
463}
464
465fn estimate_complexity(info: &SqlStatementInfo) -> Complexity {
466 let joins = info.join_count;
467 let subs = info.subquery_count;
468 if joins >= 5 || subs >= 3 {
469 Complexity::High
470 } else if joins >= 2 || subs >= 1 || info.has_aggregation {
471 Complexity::Medium
472 } else {
473 Complexity::Low
474 }
475}
476
477fn group_by_is_empty(group_by: &GroupByExpr) -> bool {
478 match group_by {
479 GroupByExpr::All(_) => true,
480 GroupByExpr::Expressions(exprs, _) => exprs.is_empty(),
481 }
482}
483
484fn add_object_name(out: &mut HashSet<String>, name: &ObjectName) {
485 if let Some(last) = name.0.last() {
486 out.insert(last.to_string());
487 } else {
488 out.insert(name.to_string());
489 }
490}
491
492fn verb_for(stmt: &Statement) -> String {
493 match stmt {
494 Statement::Query(_) => "SELECT".to_string(),
495 Statement::Insert(_) => "INSERT".to_string(),
496 Statement::Update { .. } => "UPDATE".to_string(),
497 Statement::Delete(_) => "DELETE".to_string(),
498 Statement::Truncate { .. } => "TRUNCATE".to_string(),
499 Statement::CreateTable(_) => "CREATE TABLE".to_string(),
500 Statement::AlterTable { .. } => "ALTER TABLE".to_string(),
501 Statement::Drop { .. } => "DROP".to_string(),
502 Statement::CreateIndex(_) => "CREATE INDEX".to_string(),
503 Statement::CreateView { .. } => "CREATE VIEW".to_string(),
504 Statement::Grant { .. } => "GRANT".to_string(),
505 Statement::Revoke { .. } => "REVOKE".to_string(),
506 other => format!("{:?}", other)
507 .split('(')
508 .next()
509 .unwrap_or("OTHER")
510 .to_uppercase(),
511 }
512}
513
514#[derive(Debug, Clone)]
517enum DialectBox {
518 Generic,
519}
520
521impl DialectBox {
522 fn as_dialect(&self) -> &dyn Dialect {
523 match self {
524 Self::Generic => &GenericDialect {},
525 }
526 }
527}
528
#[cfg(test)]
mod tests {
    use super::*;

    /// Validates `sql` with a fresh validator, panicking if it is rejected.
    fn parse(sql: &str) -> SqlStatementInfo {
        SqlValidator::new().validate(sql).unwrap()
    }

    /// Returns `true` when validating `sql` fails with a parse error.
    fn is_parse_error(sql: &str) -> bool {
        matches!(
            SqlValidator::new().validate(sql),
            Err(ValidationError::ParseError { .. })
        )
    }

    #[test]
    fn select_simple() {
        let info = parse("SELECT id, name FROM users");
        assert_eq!(info.statement_type, SqlStatementType::Select);
        assert!(info.tables.contains("users"));
        for column in ["id", "name"] {
            assert!(info.columns.contains(column));
        }
        assert!(!info.has_where);
        assert!(!info.has_limit);
    }

    #[test]
    fn select_with_where_limit_order() {
        let info = parse("SELECT id FROM users WHERE active = 1 ORDER BY id LIMIT 10");
        assert!(info.has_where);
        assert!(info.has_limit);
        assert!(info.has_order_by);
        assert_eq!(info.estimated_rows, 10);
    }

    #[test]
    fn select_star() {
        let info = parse("SELECT * FROM users");
        assert!(info.columns.contains("*"));
    }

    #[test]
    fn select_join_and_subquery() {
        let info = parse(
            "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id \
             WHERE u.id IN (SELECT id FROM admins)",
        );
        assert_eq!(info.join_count, 1);
        assert!(info.subquery_count >= 1);
        for table in ["users", "orders", "admins"] {
            assert!(info.tables.contains(table));
        }
    }

    #[test]
    fn insert_extracts_table_and_columns() {
        let info = parse("INSERT INTO users (id, name) VALUES (1, 'Alice')");
        assert_eq!(info.statement_type, SqlStatementType::Insert);
        assert!(info.tables.contains("users"));
        for column in ["id", "name"] {
            assert!(info.columns.contains(column));
        }
    }

    #[test]
    fn update_without_where_flagged() {
        let validator = SqlValidator::new();
        let info = validator.validate("UPDATE users SET active = 0").unwrap();
        assert_eq!(info.statement_type, SqlStatementType::Update);
        assert!(!info.has_where);

        let analysis = validator.analyze_security(&info);
        let flagged = analysis
            .potential_issues
            .iter()
            .any(|i| i.issue_type == SecurityIssueType::UnboundedQuery);
        assert!(flagged);
    }

    #[test]
    fn update_with_where() {
        let info = parse("UPDATE users SET active = 0 WHERE id = 1");
        assert_eq!(info.statement_type, SqlStatementType::Update);
        assert!(info.has_where);
        assert!(info.columns.contains("active"));
    }

    #[test]
    fn delete_with_where() {
        let info = parse("DELETE FROM users WHERE id = 1");
        assert_eq!(info.statement_type, SqlStatementType::Delete);
        assert!(info.has_where);
    }

    #[test]
    fn ddl_is_admin() {
        let info = parse("CREATE TABLE foo (id INT)");
        assert_eq!(info.statement_type, SqlStatementType::Ddl);
        assert!(info.statement_type.is_admin());
    }

    #[test]
    fn empty_sql_rejected() {
        assert!(is_parse_error(""));
        assert!(is_parse_error("   "));
    }

    #[test]
    fn syntax_error_rejected() {
        assert!(is_parse_error("SELEC id FRM users"));
    }

    #[test]
    fn multiple_statements_rejected() {
        assert!(is_parse_error("SELECT 1; SELECT 2"));
    }

    #[test]
    fn aggregation_detected() {
        let info = parse("SELECT COUNT(*) FROM users");
        assert!(info.has_aggregation);
    }

    #[test]
    fn group_by_detected() {
        let info = parse("SELECT role, COUNT(*) FROM users GROUP BY role");
        assert!(info.has_aggregation);
    }
}