1use crate::ast::{Expr, QailCmd};
7use std::collections::HashMap;
8use strsim::levenshtein;
9
/// A problem found while checking a command against the registered schema.
///
/// Every variant owns its strings, so an error can be stored or returned
/// independently of the validator and the command that produced it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// The named table is not registered.
    TableNotFound {
        /// Table name as it was looked up.
        table: String,
        /// A similar registered table name, when one is close enough.
        suggestion: Option<String>,
    },
    /// The table is registered but has no column with this name.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Column name as it was looked up.
        column: String,
        /// A similar column of the same table, when one is close enough.
        suggestion: Option<String>,
    },
    /// A type other than the expected one was supplied for `table.column`.
    TypeMismatch {
        /// Table owning the column.
        table: String,
        /// Column whose type did not match.
        column: String,
        /// Name of the type that was expected.
        expected: String,
        /// Name of the type that was actually supplied.
        got: String,
    },
    /// An operator was rejected for a column.
    InvalidOperator {
        /// Column the operator was applied to.
        column: String,
        /// The rejected operator, as text.
        operator: String,
        /// Human-readable explanation of the rejection.
        reason: String,
    },
}
38
39impl std::fmt::Display for ValidationError {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 ValidationError::TableNotFound { table, suggestion } => {
43 if let Some(s) = suggestion {
44 write!(f, "Table '{}' not found. Did you mean '{}'?", table, s)
45 } else {
46 write!(f, "Table '{}' not found.", table)
47 }
48 }
49 ValidationError::ColumnNotFound {
50 table,
51 column,
52 suggestion,
53 } => {
54 if let Some(s) = suggestion {
55 write!(
56 f,
57 "Column '{}' not found in table '{}'. Did you mean '{}'?",
58 column, table, s
59 )
60 } else {
61 write!(f, "Column '{}' not found in table '{}'.", column, table)
62 }
63 }
64 ValidationError::TypeMismatch {
65 table,
66 column,
67 expected,
68 got,
69 } => {
70 write!(
71 f,
72 "Type mismatch for '{}.{}': expected {}, got {}",
73 table, column, expected, got
74 )
75 }
76 ValidationError::InvalidOperator {
77 column,
78 operator,
79 reason,
80 } => {
81 write!(
82 f,
83 "Invalid operator '{}' for column '{}': {}",
84 operator, column, reason
85 )
86 }
87 }
88 }
89}
90
/// Lets `ValidationError` be used wherever a `std::error::Error` is expected
/// (for example boxed as `Box<dyn std::error::Error>`). The body is empty, so
/// only the trait's default methods apply.
impl std::error::Error for ValidationError {}
92
/// Outcome of validating a whole command: `Ok(())` when no problem was found,
/// otherwise `Err` with the list of problems.
pub type ValidationResult = Result<(), Vec<ValidationError>>;
95
/// Checks table and column references against a registered schema.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Names of the registered tables, in the order they were added.
    tables: Vec<String>,
    /// Column names of each registered table, keyed by table name.
    columns: HashMap<String, Vec<String>>,
    /// Declared type of each column: table name -> (column name -> type name).
    // Filled in by `add_table_with_types`; nothing visible here reads it yet,
    // which is why the dead-code lint is silenced.
    #[allow(dead_code)]
    column_types: HashMap<String, HashMap<String, String>>,
}
105
106impl Default for Validator {
107 fn default() -> Self {
108 Self::new()
109 }
110}
111
112impl Validator {
113 pub fn new() -> Self {
115 Self {
116 tables: Vec::new(),
117 columns: HashMap::new(),
118 column_types: HashMap::new(),
119 }
120 }
121
122 pub fn add_table(&mut self, table: &str, cols: &[&str]) {
124 self.tables.push(table.to_string());
125 self.columns.insert(
126 table.to_string(),
127 cols.iter().map(|s| s.to_string()).collect(),
128 );
129 }
130
131 pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
133 self.tables.push(table.to_string());
134 let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
135 self.columns.insert(table.to_string(), col_names);
136
137 let type_map: HashMap<String, String> = cols
138 .iter()
139 .map(|(name, typ)| (name.to_string(), typ.to_string()))
140 .collect();
141 self.column_types.insert(table.to_string(), type_map);
142 }
143
144 pub fn table_names(&self) -> &[String] {
146 &self.tables
147 }
148
149 pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
151 self.columns.get(table)
152 }
153
154 pub fn table_exists(&self, table: &str) -> bool {
156 self.tables.contains(&table.to_string())
157 }
158
159 pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
161 if self.tables.contains(&table.to_string()) {
162 Ok(())
163 } else {
164 let suggestion = self.did_you_mean(table, &self.tables);
165 Err(ValidationError::TableNotFound {
166 table: table.to_string(),
167 suggestion,
168 })
169 }
170 }
171
172 pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
174 if !self.tables.contains(&table.to_string()) {
176 return Ok(());
177 }
178
179 if column == "*" || column.contains('.') {
181 return Ok(());
182 }
183
184 if let Some(cols) = self.columns.get(table) {
185 if cols.contains(&column.to_string()) {
186 Ok(())
187 } else {
188 let suggestion = self.did_you_mean(column, cols);
189 Err(ValidationError::ColumnNotFound {
190 table: table.to_string(),
191 column: column.to_string(),
192 suggestion,
193 })
194 }
195 } else {
196 Ok(())
197 }
198 }
199
200 fn extract_column_name(expr: &Expr) -> Option<String> {
202 match expr {
203 Expr::Named(name) => Some(name.clone()),
204 Expr::Aliased { name, .. } => Some(name.clone()),
205 Expr::Aggregate { col, .. } => Some(col.clone()),
206 Expr::Cast { expr, .. } => Self::extract_column_name(expr),
207 Expr::JsonAccess { column, .. } => Some(column.clone()),
208 _ => None,
209 }
210 }
211
212 pub fn validate_command(&self, cmd: &QailCmd) -> ValidationResult {
215 let mut errors = Vec::new();
216
217 if let Err(e) = self.validate_table(&cmd.table) {
219 errors.push(e);
220 }
221
222 for col in &cmd.columns {
224 if let Some(name) = Self::extract_column_name(col)
225 && let Err(e) = self.validate_column(&cmd.table, &name)
226 {
227 errors.push(e);
228 }
229 }
230
231 for cage in &cmd.cages {
233 for cond in &cage.conditions {
234 if let Some(name) = Self::extract_column_name(&cond.left) {
235 if name.contains('.') {
237 let parts: Vec<&str> = name.split('.').collect();
238 if parts.len() == 2
239 && let Err(e) = self.validate_column(parts[0], parts[1])
240 {
241 errors.push(e);
242 }
243 } else if let Err(e) = self.validate_column(&cmd.table, &name) {
244 errors.push(e);
245 }
246 }
247 }
248 }
249
250 for join in &cmd.joins {
252 if let Err(e) = self.validate_table(&join.table) {
254 errors.push(e);
255 }
256
257 if let Some(conditions) = &join.on {
259 for cond in conditions {
260 if let Some(name) = Self::extract_column_name(&cond.left)
261 && name.contains('.')
262 {
263 let parts: Vec<&str> = name.split('.').collect();
264 if parts.len() == 2
265 && let Err(e) = self.validate_column(parts[0], parts[1])
266 {
267 errors.push(e);
268 }
269 }
270 if let crate::ast::Value::Column(col_name) = &cond.value
272 && col_name.contains('.')
273 {
274 let parts: Vec<&str> = col_name.split('.').collect();
275 if parts.len() == 2
276 && let Err(e) = self.validate_column(parts[0], parts[1])
277 {
278 errors.push(e);
279 }
280 }
281 }
282 }
283 }
284
285 if let Some(returning) = &cmd.returning {
287 for col in returning {
288 if let Some(name) = Self::extract_column_name(col)
289 && let Err(e) = self.validate_column(&cmd.table, &name)
290 {
291 errors.push(e);
292 }
293 }
294 }
295
296 if errors.is_empty() {
297 Ok(())
298 } else {
299 Err(errors)
300 }
301 }
302
303 fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
305 let mut best_match = None;
306 let mut min_dist = usize::MAX;
307
308 for cand in candidates {
309 let cand_str = cand.as_ref();
310 let dist = levenshtein(input, cand_str);
311
312 let threshold = match input.len() {
314 0..=2 => 0, 3..=5 => 2, _ => 3, };
318
319 if dist <= threshold && dist < min_dist {
320 min_dist = dist;
321 best_match = Some(cand_str.to_string());
322 }
323 }
324
325 best_match
326 }
327
328 #[deprecated(note = "Use validate_table() which returns ValidationError")]
334 pub fn validate_table_legacy(&self, table: &str) -> Result<(), String> {
335 self.validate_table(table).map_err(|e| e.to_string())
336 }
337
338 #[deprecated(note = "Use validate_column() which returns ValidationError")]
340 pub fn validate_column_legacy(&self, table: &str, column: &str) -> Result<(), String> {
341 self.validate_column(table, column)
342 .map_err(|e| e.to_string())
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// A misspelled table name is rejected with the closest registered name.
    #[test]
    fn test_did_you_mean_table() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("orders", &["id", "total"]);

        assert_eq!(validator.validate_table("users"), Ok(()));

        for typo in ["usr", "usrs"] {
            let suggested = match validator.validate_table(typo) {
                Err(ValidationError::TableNotFound { suggestion, .. }) => suggestion,
                _ => None,
            };
            assert_eq!(suggested.as_deref(), Some("users"));
        }
    }

    /// Known columns and `*` pass; a misspelled column gets a suggestion.
    #[test]
    fn test_did_you_mean_column() {
        let mut validator = Validator::new();
        validator.add_table("users", &["email", "password"]);

        for accepted in ["email", "*"] {
            assert_eq!(validator.validate_column("users", accepted), Ok(()));
        }

        let suggested = match validator.validate_column("users", "emial") {
            Err(ValidationError::ColumnNotFound { suggestion, .. }) => suggestion,
            _ => None,
        };
        assert_eq!(suggested.as_deref(), Some("email"));
    }

    /// Dotted column names are accepted by `validate_column` as-is.
    #[test]
    fn test_qualified_column_name() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("profiles", &["user_id", "avatar"]);

        for qualified in ["users.id", "profiles.user_id"] {
            assert_eq!(validator.validate_column("users", qualified), Ok(()));
        }
    }

    /// A command with one misspelled column yields exactly that one error.
    #[test]
    fn test_validate_command() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "email", "name"]);

        let valid = QailCmd::get("users").columns(["id", "email"]);
        assert_eq!(validator.validate_command(&valid), Ok(()));

        let with_typo = QailCmd::get("users").columns(["id", "emial"]);
        let errors = validator.validate_command(&with_typo).unwrap_err();
        assert_eq!(errors.len(), 1);

        let flags_typo = match errors.first() {
            Some(ValidationError::ColumnNotFound { column, .. }) => column == "emial",
            _ => false,
        };
        assert!(flags_typo);
    }

    /// `Display` output includes the "Did you mean" hint.
    #[test]
    fn test_error_display() {
        let cases = [
            (
                ValidationError::TableNotFound {
                    table: "usrs".to_string(),
                    suggestion: Some("users".to_string()),
                },
                "Table 'usrs' not found. Did you mean 'users'?",
            ),
            (
                ValidationError::ColumnNotFound {
                    table: "users".to_string(),
                    column: "emial".to_string(),
                    suggestion: Some("email".to_string()),
                },
                "Column 'emial' not found in table 'users'. Did you mean 'email'?",
            ),
        ];

        for (error, rendered) in cases {
            assert_eq!(error.to_string(), rendered);
        }
    }
}