1use crate::ast::{Expr, Qail};
7use std::collections::HashMap;
8use strsim::levenshtein;
9
/// A single problem reported while checking a query against the registered schema.
///
/// `Eq` is derived next to `PartialEq` because every field is a `String` or an
/// `Option<String>`, so equality is total and the type can be used wherever an
/// `Eq` bound is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// The named table is not registered.
    TableNotFound {
        /// Table name that was looked up.
        table: String,
        /// A similarly spelled registered table, when one was found.
        suggestion: Option<String>,
    },
    /// The named column is not registered for the table.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Column name that was looked up.
        column: String,
        /// A similarly spelled column of the same table, when one was found.
        suggestion: Option<String>,
    },
    /// A value's type does not fit the declared type of the column.
    TypeMismatch {
        /// Table owning the column.
        table: String,
        /// Column the value was compared against.
        column: String,
        /// Declared type of the column.
        expected: String,
        /// Type label of the supplied value.
        got: String,
    },
    /// An operator was rejected for a column.
    InvalidOperator {
        /// Column the operator was applied to.
        column: String,
        /// The rejected operator.
        operator: String,
        /// Explanation of why the operator was rejected.
        reason: String,
    },
}
36
37impl std::fmt::Display for ValidationError {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 match self {
40 ValidationError::TableNotFound { table, suggestion } => {
41 if let Some(s) = suggestion {
42 write!(f, "Table '{}' not found. Did you mean '{}'?", table, s)
43 } else {
44 write!(f, "Table '{}' not found.", table)
45 }
46 }
47 ValidationError::ColumnNotFound {
48 table,
49 column,
50 suggestion,
51 } => {
52 if let Some(s) = suggestion {
53 write!(
54 f,
55 "Column '{}' not found in table '{}'. Did you mean '{}'?",
56 column, table, s
57 )
58 } else {
59 write!(f, "Column '{}' not found in table '{}'.", column, table)
60 }
61 }
62 ValidationError::TypeMismatch {
63 table,
64 column,
65 expected,
66 got,
67 } => {
68 write!(
69 f,
70 "Type mismatch for '{}.{}': expected {}, got {}",
71 table, column, expected, got
72 )
73 }
74 ValidationError::InvalidOperator {
75 column,
76 operator,
77 reason,
78 } => {
79 write!(
80 f,
81 "Invalid operator '{}' for column '{}': {}",
82 operator, column, reason
83 )
84 }
85 }
86 }
87}
88
89impl std::error::Error for ValidationError {}
90
91pub type ValidationResult = Result<(), Vec<ValidationError>>;
93
/// Holds the known tables, their columns and (optionally) their column types,
/// and checks names and values against them.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Registered table names, in the order they were added.
    tables: Vec<String>,
    /// Column names keyed by table name.
    columns: HashMap<String, Vec<String>>,
    /// Declared column types keyed by table name, then by column name.
    /// Filled by `add_table_with_types` and read by `get_column_type`, so no
    /// `dead_code` suppression is needed on this field.
    column_types: HashMap<String, HashMap<String, String>>,
}
102
103impl Default for Validator {
104 fn default() -> Self {
105 Self::new()
106 }
107}
108
109impl Validator {
110 pub fn new() -> Self {
112 Self {
113 tables: Vec::new(),
114 columns: HashMap::new(),
115 column_types: HashMap::new(),
116 }
117 }
118
119 pub fn add_table(&mut self, table: &str, cols: &[&str]) {
121 self.tables.push(table.to_string());
122 self.columns.insert(
123 table.to_string(),
124 cols.iter().map(|s| s.to_string()).collect(),
125 );
126 }
127
128 pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
130 self.tables.push(table.to_string());
131 let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
132 self.columns.insert(table.to_string(), col_names);
133
134 let type_map: HashMap<String, String> = cols
135 .iter()
136 .map(|(name, typ)| (name.to_string(), typ.to_string()))
137 .collect();
138 self.column_types.insert(table.to_string(), type_map);
139 }
140
141 pub fn table_names(&self) -> &[String] {
143 &self.tables
144 }
145
146 pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
148 self.columns.get(table)
149 }
150
151 pub fn table_exists(&self, table: &str) -> bool {
153 self.tables.contains(&table.to_string())
154 }
155
156 pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
158 if self.tables.contains(&table.to_string()) {
159 Ok(())
160 } else {
161 let suggestion = self.did_you_mean(table, &self.tables);
162 Err(ValidationError::TableNotFound {
163 table: table.to_string(),
164 suggestion,
165 })
166 }
167 }
168
169 pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
171 if !self.tables.contains(&table.to_string()) {
173 return Ok(());
174 }
175
176 if column == "*" || column.contains('.') {
178 return Ok(());
179 }
180
181 if let Some(cols) = self.columns.get(table) {
182 if cols.contains(&column.to_string()) {
183 Ok(())
184 } else {
185 let suggestion = self.did_you_mean(column, cols);
186 Err(ValidationError::ColumnNotFound {
187 table: table.to_string(),
188 column: column.to_string(),
189 suggestion,
190 })
191 }
192 } else {
193 Ok(())
194 }
195 }
196
197 fn extract_column_name(expr: &Expr) -> Option<String> {
199 match expr {
200 Expr::Named(name) => Some(name.clone()),
201 Expr::Aliased { name, .. } => Some(name.clone()),
202 Expr::Aggregate { col, .. } => Some(col.clone()),
203 Expr::Cast { expr, .. } => Self::extract_column_name(expr),
204 Expr::JsonAccess { column, .. } => Some(column.clone()),
205 _ => None,
206 }
207 }
208
209 pub fn get_column_type(&self, table: &str, column: &str) -> Option<&String> {
211 self.column_types.get(table)?.get(column)
212 }
213
214 pub fn validate_value_type(
217 &self,
218 table: &str,
219 column: &str,
220 value: &crate::ast::Value,
221 ) -> Result<(), ValidationError> {
222 use crate::ast::Value;
223
224 let expected_type = match self.get_column_type(table, column) {
226 Some(t) => t.to_uppercase(),
227 None => return Ok(()), };
229
230 if matches!(value, Value::Null | Value::NullUuid) {
232 return Ok(());
233 }
234
235 if matches!(value, Value::Param(_) | Value::NamedParam(_) | Value::Function(_) | Value::Subquery(_) | Value::Expr(_)) {
237 return Ok(());
238 }
239
240 let value_type = match value {
242 Value::Bool(_) => "BOOLEAN",
243 Value::Int(_) => "INT",
244 Value::Float(_) => "FLOAT",
245 Value::String(_) => "TEXT",
246 Value::Uuid(_) => "UUID",
247 Value::Array(_) => "ARRAY",
248 Value::Column(_) => return Ok(()), Value::Interval { .. } => "INTERVAL",
250 Value::Timestamp(_) => "TIMESTAMP",
251 Value::Bytes(_) => "BYTEA",
252 Value::Vector(_) => "VECTOR",
253 _ => return Ok(()), };
255
256 if !Self::types_compatible(&expected_type, value_type) {
258 return Err(ValidationError::TypeMismatch {
259 table: table.to_string(),
260 column: column.to_string(),
261 expected: expected_type,
262 got: value_type.to_string(),
263 });
264 }
265
266 Ok(())
267 }
268
269 fn types_compatible(expected: &str, value_type: &str) -> bool {
272 let expected = expected.to_uppercase();
273 let value_type = value_type.to_uppercase();
274
275 if expected == value_type {
277 return true;
278 }
279
280 let int_types = ["INT", "INT4", "INT8", "INTEGER", "BIGINT", "SMALLINT", "SERIAL", "BIGSERIAL"];
282 if int_types.contains(&expected.as_str()) && value_type == "INT" {
283 return true;
284 }
285
286 let float_types = ["FLOAT", "FLOAT4", "FLOAT8", "DOUBLE", "DECIMAL", "NUMERIC", "REAL"];
288 if float_types.contains(&expected.as_str()) && (value_type == "FLOAT" || value_type == "INT") {
289 return true;
290 }
291
292 let text_types = ["TEXT", "VARCHAR", "CHAR", "CHARACTER", "CITEXT", "NAME"];
294 if text_types.contains(&expected.as_str()) && value_type == "TEXT" {
295 return true;
296 }
297
298 if (expected == "BOOLEAN" || expected == "BOOL") && value_type == "BOOLEAN" {
300 return true;
301 }
302
303 if expected == "UUID" && (value_type == "UUID" || value_type == "TEXT") {
305 return true;
306 }
307
308 let ts_types = ["TIMESTAMP", "TIMESTAMPTZ", "DATE", "TIME", "TIMETZ"];
310 if ts_types.contains(&expected.as_str()) && (value_type == "TIMESTAMP" || value_type == "TEXT") {
311 return true;
312 }
313
314 if expected == "JSONB" || expected == "JSON" {
316 return true;
317 }
318
319 if expected.contains("[]") || expected.starts_with("_") {
321 return value_type == "ARRAY";
322 }
323
324 false
325 }
326
327 pub fn validate_command(&self, cmd: &Qail) -> ValidationResult {
329 let mut errors = Vec::new();
330
331 if let Err(e) = self.validate_table(&cmd.table) {
332 errors.push(e);
333 }
334
335 for col in &cmd.columns {
336 if let Some(name) = Self::extract_column_name(col)
337 && let Err(e) = self.validate_column(&cmd.table, &name)
338 {
339 errors.push(e);
340 }
341 }
342
343 for cage in &cmd.cages {
344 for cond in &cage.conditions {
345 if let Some(name) = Self::extract_column_name(&cond.left) {
346 if name.contains('.') {
348 let parts: Vec<&str> = name.split('.').collect();
349 if parts.len() == 2 {
350 if let Err(e) = self.validate_column(parts[0], parts[1]) {
351 errors.push(e);
352 }
353 if let Err(e) = self.validate_value_type(parts[0], parts[1], &cond.value) {
355 errors.push(e);
356 }
357 }
358 } else {
359 if let Err(e) = self.validate_column(&cmd.table, &name) {
360 errors.push(e);
361 }
362 if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
364 errors.push(e);
365 }
366 }
367 }
368 }
369 }
370
371 for join in &cmd.joins {
372
373 if let Err(e) = self.validate_table(&join.table) {
374 errors.push(e);
375 }
376
377
378 if let Some(conditions) = &join.on {
379 for cond in conditions {
380 if let Some(name) = Self::extract_column_name(&cond.left)
381 && name.contains('.')
382 {
383 let parts: Vec<&str> = name.split('.').collect();
384 if parts.len() == 2
385 && let Err(e) = self.validate_column(parts[0], parts[1])
386 {
387 errors.push(e);
388 }
389 }
390 if let crate::ast::Value::Column(col_name) = &cond.value
392 && col_name.contains('.')
393 {
394 let parts: Vec<&str> = col_name.split('.').collect();
395 if parts.len() == 2
396 && let Err(e) = self.validate_column(parts[0], parts[1])
397 {
398 errors.push(e);
399 }
400 }
401 }
402 }
403 }
404
405 if let Some(returning) = &cmd.returning {
406 for col in returning {
407 if let Some(name) = Self::extract_column_name(col)
408 && let Err(e) = self.validate_column(&cmd.table, &name)
409 {
410 errors.push(e);
411 }
412 }
413 }
414
415 if errors.is_empty() {
416 Ok(())
417 } else {
418 Err(errors)
419 }
420 }
421
422 fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
424 let mut best_match = None;
425 let mut min_dist = usize::MAX;
426
427 for cand in candidates {
428 let cand_str = cand.as_ref();
429 let dist = levenshtein(input, cand_str);
430
431 let threshold = match input.len() {
433 0..=2 => 0, 3..=5 => 2, _ => 3, };
437
438 if dist <= threshold && dist < min_dist {
439 min_dist = dist;
440 best_match = Some(cand_str.to_string());
441 }
442 }
443
444 best_match
445 }
446
447 #[deprecated(note = "Use validate_table() which returns ValidationError")]
453 pub fn validate_table_legacy(&self, table: &str) -> Result<(), String> {
454 self.validate_table(table).map_err(|e| e.to_string())
455 }
456
457 #[deprecated(note = "Use validate_column() which returns ValidationError")]
459 pub fn validate_column_legacy(&self, table: &str, column: &str) -> Result<(), String> {
460 self.validate_column(table, column)
461 .map_err(|e| e.to_string())
462 }
463}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Misspelled table names are rejected with the closest table suggested.
    #[test]
    fn test_did_you_mean_table() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("orders", &["id", "total"]);

        assert_eq!(validator.validate_table("users"), Ok(()));

        for typo in ["usr", "usrs"] {
            match validator.validate_table(typo) {
                Err(ValidationError::TableNotFound { suggestion, .. }) => {
                    assert_eq!(suggestion.as_deref(), Some("users"));
                }
                other => panic!("expected TableNotFound for {typo:?}, got {other:?}"),
            }
        }
    }

    /// Misspelled column names are rejected with the closest column suggested.
    #[test]
    fn test_did_you_mean_column() {
        let mut validator = Validator::new();
        validator.add_table("users", &["email", "password"]);

        assert_eq!(validator.validate_column("users", "email"), Ok(()));
        assert_eq!(validator.validate_column("users", "*"), Ok(()));

        match validator.validate_column("users", "emial") {
            Err(ValidationError::ColumnNotFound { suggestion, .. }) => {
                assert_eq!(suggestion.as_deref(), Some("email"));
            }
            other => panic!("expected ColumnNotFound, got {other:?}"),
        }
    }

    /// Qualified `table.column` names are accepted by `validate_column`.
    #[test]
    fn test_qualified_column_name() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("profiles", &["user_id", "avatar"]);

        for qualified in ["users.id", "profiles.user_id"] {
            assert_eq!(validator.validate_column("users", qualified), Ok(()));
        }
    }

    /// A command with one misspelled column yields exactly one error.
    #[test]
    fn test_validate_command() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "email", "name"]);

        let valid = Qail::get("users").columns(["id", "email"]);
        assert!(validator.validate_command(&valid).is_ok());

        let misspelled = Qail::get("users").columns(["id", "emial"]);
        let errors = validator.validate_command(&misspelled).unwrap_err();
        assert_eq!(errors.len(), 1);
        match &errors[0] {
            ValidationError::ColumnNotFound { column, .. } => {
                assert_eq!(column.as_str(), "emial");
            }
            other => panic!("expected ColumnNotFound, got {other:?}"),
        }
    }

    /// `Display` output includes the suggestion hint.
    #[test]
    fn test_error_display() {
        let cases = [
            (
                ValidationError::TableNotFound {
                    table: "usrs".to_string(),
                    suggestion: Some("users".to_string()),
                },
                "Table 'usrs' not found. Did you mean 'users'?",
            ),
            (
                ValidationError::ColumnNotFound {
                    table: "users".to_string(),
                    column: "emial".to_string(),
                    suggestion: Some("email".to_string()),
                },
                "Column 'emial' not found in table 'users'. Did you mean 'email'?",
            ),
        ];

        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
        }
    }
}