1use crate::ast::{Expr, Qail};
7use std::collections::HashMap;
8use strsim::levenshtein;
9
/// A single problem found while checking a query against the registered schema.
///
/// Each variant carries the names needed to render a readable message through
/// its `Display` implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// A referenced table is not registered.
    TableNotFound {
        /// Table name as it appeared in the query.
        table: String,
        /// Closest registered table name, if one is similar enough.
        suggestion: Option<String>,
    },
    /// A referenced column is not registered for its table.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Column name as it appeared in the query.
        column: String,
        /// Closest column name of `table`, if one is similar enough.
        suggestion: Option<String>,
    },
    /// `table.column` was expected to have type `expected` but `got` was found.
    TypeMismatch {
        table: String,
        column: String,
        /// Expected type name.
        expected: String,
        /// Type name actually encountered.
        got: String,
    },
    /// `operator` cannot be used with `column`; `reason` explains why.
    InvalidOperator {
        column: String,
        operator: String,
        reason: String,
    },
}
36
37impl std::fmt::Display for ValidationError {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 match self {
40 ValidationError::TableNotFound { table, suggestion } => {
41 if let Some(s) = suggestion {
42 write!(f, "Table '{}' not found. Did you mean '{}'?", table, s)
43 } else {
44 write!(f, "Table '{}' not found.", table)
45 }
46 }
47 ValidationError::ColumnNotFound {
48 table,
49 column,
50 suggestion,
51 } => {
52 if let Some(s) = suggestion {
53 write!(
54 f,
55 "Column '{}' not found in table '{}'. Did you mean '{}'?",
56 column, table, s
57 )
58 } else {
59 write!(f, "Column '{}' not found in table '{}'.", column, table)
60 }
61 }
62 ValidationError::TypeMismatch {
63 table,
64 column,
65 expected,
66 got,
67 } => {
68 write!(
69 f,
70 "Type mismatch for '{}.{}': expected {}, got {}",
71 table, column, expected, got
72 )
73 }
74 ValidationError::InvalidOperator {
75 column,
76 operator,
77 reason,
78 } => {
79 write!(
80 f,
81 "Invalid operator '{}' for column '{}': {}",
82 operator, column, reason
83 )
84 }
85 }
86 }
87}
88
89impl std::error::Error for ValidationError {}
90
91pub type ValidationResult = Result<(), Vec<ValidationError>>;
93
/// Registry of known tables and columns used to check references in queries.
///
/// Failed lookups yield `ValidationError`s, which may carry a
/// "did you mean" suggestion.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Registered table names, in registration order.
    tables: Vec<String>,
    /// Column names of each registered table, keyed by table name.
    columns: HashMap<String, Vec<String>>,
    /// Declared column type names per table (`table -> column -> type`).
    // NOTE(review): these types are only ever written in the visible code; the
    // `allow` presumably silences the unused-field warning until something
    // reads them — remove it once a reader exists.
    #[allow(dead_code)]
    column_types: HashMap<String, HashMap<String, String>>,
}
102
103impl Default for Validator {
104 fn default() -> Self {
105 Self::new()
106 }
107}
108
109impl Validator {
110 pub fn new() -> Self {
112 Self {
113 tables: Vec::new(),
114 columns: HashMap::new(),
115 column_types: HashMap::new(),
116 }
117 }
118
119 pub fn add_table(&mut self, table: &str, cols: &[&str]) {
121 self.tables.push(table.to_string());
122 self.columns.insert(
123 table.to_string(),
124 cols.iter().map(|s| s.to_string()).collect(),
125 );
126 }
127
128 pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
130 self.tables.push(table.to_string());
131 let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
132 self.columns.insert(table.to_string(), col_names);
133
134 let type_map: HashMap<String, String> = cols
135 .iter()
136 .map(|(name, typ)| (name.to_string(), typ.to_string()))
137 .collect();
138 self.column_types.insert(table.to_string(), type_map);
139 }
140
141 pub fn table_names(&self) -> &[String] {
143 &self.tables
144 }
145
146 pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
148 self.columns.get(table)
149 }
150
151 pub fn table_exists(&self, table: &str) -> bool {
153 self.tables.contains(&table.to_string())
154 }
155
156 pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
158 if self.tables.contains(&table.to_string()) {
159 Ok(())
160 } else {
161 let suggestion = self.did_you_mean(table, &self.tables);
162 Err(ValidationError::TableNotFound {
163 table: table.to_string(),
164 suggestion,
165 })
166 }
167 }
168
169 pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
171 if !self.tables.contains(&table.to_string()) {
173 return Ok(());
174 }
175
176 if column == "*" || column.contains('.') {
178 return Ok(());
179 }
180
181 if let Some(cols) = self.columns.get(table) {
182 if cols.contains(&column.to_string()) {
183 Ok(())
184 } else {
185 let suggestion = self.did_you_mean(column, cols);
186 Err(ValidationError::ColumnNotFound {
187 table: table.to_string(),
188 column: column.to_string(),
189 suggestion,
190 })
191 }
192 } else {
193 Ok(())
194 }
195 }
196
197 fn extract_column_name(expr: &Expr) -> Option<String> {
199 match expr {
200 Expr::Named(name) => Some(name.clone()),
201 Expr::Aliased { name, .. } => Some(name.clone()),
202 Expr::Aggregate { col, .. } => Some(col.clone()),
203 Expr::Cast { expr, .. } => Self::extract_column_name(expr),
204 Expr::JsonAccess { column, .. } => Some(column.clone()),
205 _ => None,
206 }
207 }
208
209 pub fn validate_command(&self, cmd: &Qail) -> ValidationResult {
211 let mut errors = Vec::new();
212
213 if let Err(e) = self.validate_table(&cmd.table) {
214 errors.push(e);
215 }
216
217 for col in &cmd.columns {
218 if let Some(name) = Self::extract_column_name(col)
219 && let Err(e) = self.validate_column(&cmd.table, &name)
220 {
221 errors.push(e);
222 }
223 }
224
225 for cage in &cmd.cages {
226 for cond in &cage.conditions {
227 if let Some(name) = Self::extract_column_name(&cond.left) {
228 if name.contains('.') {
230 let parts: Vec<&str> = name.split('.').collect();
231 if parts.len() == 2
232 && let Err(e) = self.validate_column(parts[0], parts[1])
233 {
234 errors.push(e);
235 }
236 } else if let Err(e) = self.validate_column(&cmd.table, &name) {
237 errors.push(e);
238 }
239 }
240 }
241 }
242
243 for join in &cmd.joins {
244 if let Err(e) = self.validate_table(&join.table) {
246 errors.push(e);
247 }
248
249 if let Some(conditions) = &join.on {
251 for cond in conditions {
252 if let Some(name) = Self::extract_column_name(&cond.left)
253 && name.contains('.')
254 {
255 let parts: Vec<&str> = name.split('.').collect();
256 if parts.len() == 2
257 && let Err(e) = self.validate_column(parts[0], parts[1])
258 {
259 errors.push(e);
260 }
261 }
262 if let crate::ast::Value::Column(col_name) = &cond.value
264 && col_name.contains('.')
265 {
266 let parts: Vec<&str> = col_name.split('.').collect();
267 if parts.len() == 2
268 && let Err(e) = self.validate_column(parts[0], parts[1])
269 {
270 errors.push(e);
271 }
272 }
273 }
274 }
275 }
276
277 if let Some(returning) = &cmd.returning {
278 for col in returning {
279 if let Some(name) = Self::extract_column_name(col)
280 && let Err(e) = self.validate_column(&cmd.table, &name)
281 {
282 errors.push(e);
283 }
284 }
285 }
286
287 if errors.is_empty() {
288 Ok(())
289 } else {
290 Err(errors)
291 }
292 }
293
294 fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
296 let mut best_match = None;
297 let mut min_dist = usize::MAX;
298
299 for cand in candidates {
300 let cand_str = cand.as_ref();
301 let dist = levenshtein(input, cand_str);
302
303 let threshold = match input.len() {
305 0..=2 => 0, 3..=5 => 2, _ => 3, };
309
310 if dist <= threshold && dist < min_dist {
311 min_dist = dist;
312 best_match = Some(cand_str.to_string());
313 }
314 }
315
316 best_match
317 }
318
319 #[deprecated(note = "Use validate_table() which returns ValidationError")]
325 pub fn validate_table_legacy(&self, table: &str) -> Result<(), String> {
326 self.validate_table(table).map_err(|e| e.to_string())
327 }
328
329 #[deprecated(note = "Use validate_column() which returns ValidationError")]
331 pub fn validate_column_legacy(&self, table: &str, column: &str) -> Result<(), String> {
332 self.validate_column(table, column)
333 .map_err(|e| e.to_string())
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_did_you_mean_table() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("orders", &["id", "total"]);

        assert!(v.validate_table("users").is_ok());

        let err = v.validate_table("usr").unwrap_err();
        assert!(
            matches!(err, ValidationError::TableNotFound { suggestion: Some(ref s), .. } if s == "users")
        );

        let err = v.validate_table("usrs").unwrap_err();
        assert!(
            matches!(err, ValidationError::TableNotFound { suggestion: Some(ref s), .. } if s == "users")
        );
    }

    #[test]
    fn test_no_suggestion_when_nothing_is_close() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "email"]);

        // Inputs of at most two characters never receive a suggestion.
        assert_eq!(
            v.validate_table("us").unwrap_err(),
            ValidationError::TableNotFound {
                table: "us".to_string(),
                suggestion: None,
            }
        );
        assert!(matches!(
            v.validate_table("invoices").unwrap_err(),
            ValidationError::TableNotFound { suggestion: None, .. }
        ));
        assert!(matches!(
            v.validate_column("users", "zzzzzz").unwrap_err(),
            ValidationError::ColumnNotFound { suggestion: None, .. }
        ));
    }

    #[test]
    fn test_did_you_mean_column() {
        let mut v = Validator::new();
        v.add_table("users", &["email", "password"]);

        assert!(v.validate_column("users", "email").is_ok());
        assert!(v.validate_column("users", "*").is_ok());

        let err = v.validate_column("users", "emial").unwrap_err();
        assert!(
            matches!(err, ValidationError::ColumnNotFound { suggestion: Some(ref s), .. } if s == "email")
        );
    }

    #[test]
    fn test_unknown_table_skips_column_check() {
        let v = Validator::new();
        assert!(!v.table_exists("ghost"));
        assert!(v.validate_column("ghost", "anything").is_ok());
    }

    #[test]
    fn test_qualified_column_name() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("profiles", &["user_id", "avatar"]);

        assert!(v.validate_column("users", "users.id").is_ok());
        assert!(v.validate_column("users", "profiles.user_id").is_ok());
    }

    #[test]
    fn test_validate_command() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "email", "name"]);

        let cmd = Qail::get("users").columns(["id", "email"]);
        assert!(v.validate_command(&cmd).is_ok());

        let cmd = Qail::get("users").columns(["id", "emial"]);
        let errors = v.validate_command(&cmd).unwrap_err();
        assert_eq!(errors.len(), 1);
        assert!(
            matches!(&errors[0], ValidationError::ColumnNotFound { column, .. } if column == "emial")
        );
    }

    #[test]
    fn test_error_display() {
        let err = ValidationError::TableNotFound {
            table: "usrs".to_string(),
            suggestion: Some("users".to_string()),
        };
        assert_eq!(
            err.to_string(),
            "Table 'usrs' not found. Did you mean 'users'?"
        );

        let err = ValidationError::ColumnNotFound {
            table: "users".to_string(),
            column: "emial".to_string(),
            suggestion: Some("email".to_string()),
        };
        assert_eq!(
            err.to_string(),
            "Column 'emial' not found in table 'users'. Did you mean 'email'?"
        );
    }

    #[test]
    fn test_error_display_other_variants() {
        let err = ValidationError::TableNotFound {
            table: "x".to_string(),
            suggestion: None,
        };
        assert_eq!(err.to_string(), "Table 'x' not found.");

        let err = ValidationError::ColumnNotFound {
            table: "users".to_string(),
            column: "q".to_string(),
            suggestion: None,
        };
        assert_eq!(err.to_string(), "Column 'q' not found in table 'users'.");

        let err = ValidationError::TypeMismatch {
            table: "users".to_string(),
            column: "age".to_string(),
            expected: "int".to_string(),
            got: "text".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Type mismatch for 'users.age': expected int, got text"
        );

        let err = ValidationError::InvalidOperator {
            column: "name".to_string(),
            operator: ">".to_string(),
            reason: "not numeric".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Invalid operator '>' for column 'name': not numeric"
        );
    }
}