use std::collections::HashMap;

use strsim::levenshtein;

use crate::ast::{Expr, QailCmd};
/// A problem found while checking a query against the registered schema.
///
/// The `Display` implementation renders each variant as a human-readable
/// message, including a "Did you mean ...?" hint when a suggestion is present.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// The referenced table is not registered with the validator.
    TableNotFound {
        /// Table name exactly as it appeared in the query.
        table: String,
        /// Closest registered table name, if one was similar enough.
        suggestion: Option<String>,
    },
    /// The referenced column is not registered for the given table.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Column name exactly as it appeared in the query.
        column: String,
        /// Closest registered column name, if one was similar enough.
        suggestion: Option<String>,
    },
    /// A value's type does not match the column's declared type.
    ///
    /// NOTE(review): nothing in this part of the file constructs this
    /// variant; it is presumably produced by callers elsewhere.
    TypeMismatch {
        /// Table owning the column.
        table: String,
        /// Column whose type was violated.
        column: String,
        /// Type that was expected.
        expected: String,
        /// Type that was actually supplied.
        got: String,
    },
    /// An operator was applied to a column it is not valid for.
    ///
    /// NOTE(review): nothing in this part of the file constructs this
    /// variant; it is presumably produced by callers elsewhere.
    InvalidOperator {
        /// Column the operator was applied to.
        column: String,
        /// The offending operator.
        operator: String,
        /// Free-form explanation of why the operator is invalid.
        reason: String,
    },
}
39impl std::fmt::Display for ValidationError {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 ValidationError::TableNotFound { table, suggestion } => {
43 if let Some(s) = suggestion {
44 write!(f, "Table '{}' not found. Did you mean '{}'?", table, s)
45 } else {
46 write!(f, "Table '{}' not found.", table)
47 }
48 }
49 ValidationError::ColumnNotFound { table, column, suggestion } => {
50 if let Some(s) = suggestion {
51 write!(f, "Column '{}' not found in table '{}'. Did you mean '{}'?", column, table, s)
52 } else {
53 write!(f, "Column '{}' not found in table '{}'.", column, table)
54 }
55 }
56 ValidationError::TypeMismatch { table, column, expected, got } => {
57 write!(f, "Type mismatch for '{}.{}': expected {}, got {}", table, column, expected, got)
58 }
59 ValidationError::InvalidOperator { column, operator, reason } => {
60 write!(f, "Invalid operator '{}' for column '{}': {}", operator, column, reason)
61 }
62 }
63 }
64}
65
66impl std::error::Error for ValidationError {}
67
68pub type ValidationResult = Result<(), Vec<ValidationError>>;
70
/// In-memory schema (tables and their columns) that queries are checked
/// against.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Registered table names, in the order they were added.
    tables: Vec<String>,
    /// Column names per table, keyed by table name.
    columns: HashMap<String, Vec<String>>,
    /// Column name -> type string per table, keyed by table name.
    ///
    /// Only written by `add_table_with_types`; nothing in this file reads it
    /// yet, hence the `dead_code` allowance.
    #[allow(dead_code)]
    column_types: HashMap<String, HashMap<String, String>>,
}
81impl Default for Validator {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl Validator {
88 pub fn new() -> Self {
90 Self {
91 tables: Vec::new(),
92 columns: HashMap::new(),
93 column_types: HashMap::new(),
94 }
95 }
96
97 pub fn add_table(&mut self, table: &str, cols: &[&str]) {
99 self.tables.push(table.to_string());
100 self.columns.insert(
101 table.to_string(),
102 cols.iter().map(|s| s.to_string()).collect(),
103 );
104 }
105
106 pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
108 self.tables.push(table.to_string());
109 let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
110 self.columns.insert(table.to_string(), col_names);
111
112 let type_map: HashMap<String, String> = cols.iter()
113 .map(|(name, typ)| (name.to_string(), typ.to_string()))
114 .collect();
115 self.column_types.insert(table.to_string(), type_map);
116 }
117
118 pub fn table_names(&self) -> &[String] {
120 &self.tables
121 }
122
123 pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
125 self.columns.get(table)
126 }
127
128 pub fn table_exists(&self, table: &str) -> bool {
130 self.tables.contains(&table.to_string())
131 }
132
133 pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
135 if self.tables.contains(&table.to_string()) {
136 Ok(())
137 } else {
138 let suggestion = self.did_you_mean(table, &self.tables);
139 Err(ValidationError::TableNotFound {
140 table: table.to_string(),
141 suggestion,
142 })
143 }
144 }
145
146 pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
148 if !self.tables.contains(&table.to_string()) {
150 return Ok(());
151 }
152
153 if column == "*" || column.contains('.') {
155 return Ok(());
156 }
157
158 if let Some(cols) = self.columns.get(table) {
159 if cols.contains(&column.to_string()) {
160 Ok(())
161 } else {
162 let suggestion = self.did_you_mean(column, cols);
163 Err(ValidationError::ColumnNotFound {
164 table: table.to_string(),
165 column: column.to_string(),
166 suggestion,
167 })
168 }
169 } else {
170 Ok(())
171 }
172 }
173
174 fn extract_column_name(expr: &Expr) -> Option<String> {
176 match expr {
177 Expr::Named(name) => Some(name.clone()),
178 Expr::Aliased { name, .. } => Some(name.clone()),
179 Expr::Aggregate { col, .. } => Some(col.clone()),
180 Expr::Cast { expr, .. } => Self::extract_column_name(expr),
181 Expr::JsonAccess { column, .. } => Some(column.clone()),
182 _ => None,
183 }
184 }
185
186 pub fn validate_command(&self, cmd: &QailCmd) -> ValidationResult {
189 let mut errors = Vec::new();
190
191 if let Err(e) = self.validate_table(&cmd.table) {
193 errors.push(e);
194 }
195
196 for col in &cmd.columns {
198 if let Some(name) = Self::extract_column_name(col) {
199 if let Err(e) = self.validate_column(&cmd.table, &name) {
200 errors.push(e);
201 }
202 }
203 }
204
205 for cage in &cmd.cages {
207 for cond in &cage.conditions {
208 if let Some(name) = Self::extract_column_name(&cond.left) {
209 if name.contains('.') {
211 let parts: Vec<&str> = name.split('.').collect();
212 if parts.len() == 2 {
213 if let Err(e) = self.validate_column(parts[0], parts[1]) {
214 errors.push(e);
215 }
216 }
217 } else if let Err(e) = self.validate_column(&cmd.table, &name) {
218 errors.push(e);
219 }
220 }
221 }
222 }
223
224 for join in &cmd.joins {
226 if let Err(e) = self.validate_table(&join.table) {
228 errors.push(e);
229 }
230
231 if let Some(conditions) = &join.on {
233 for cond in conditions {
234 if let Some(name) = Self::extract_column_name(&cond.left) {
235 if name.contains('.') {
236 let parts: Vec<&str> = name.split('.').collect();
237 if parts.len() == 2 {
238 if let Err(e) = self.validate_column(parts[0], parts[1]) {
239 errors.push(e);
240 }
241 }
242 }
243 }
244 if let crate::ast::Value::Column(col_name) = &cond.value {
246 if col_name.contains('.') {
247 let parts: Vec<&str> = col_name.split('.').collect();
248 if parts.len() == 2 {
249 if let Err(e) = self.validate_column(parts[0], parts[1]) {
250 errors.push(e);
251 }
252 }
253 }
254 }
255 }
256 }
257 }
258
259 if let Some(returning) = &cmd.returning {
261 for col in returning {
262 if let Some(name) = Self::extract_column_name(col) {
263 if let Err(e) = self.validate_column(&cmd.table, &name) {
264 errors.push(e);
265 }
266 }
267 }
268 }
269
270 if errors.is_empty() {
271 Ok(())
272 } else {
273 Err(errors)
274 }
275 }
276
277 fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
279 let mut best_match = None;
280 let mut min_dist = usize::MAX;
281
282 for cand in candidates {
283 let cand_str = cand.as_ref();
284 let dist = levenshtein(input, cand_str);
285
286 let threshold = match input.len() {
288 0..=2 => 0, 3..=5 => 2, _ => 3, };
292
293 if dist <= threshold && dist < min_dist {
294 min_dist = dist;
295 best_match = Some(cand_str.to_string());
296 }
297 }
298
299 best_match
300 }
301
302 #[deprecated(note = "Use validate_table() which returns ValidationError")]
308 pub fn validate_table_legacy(&self, table: &str) -> Result<(), String> {
309 self.validate_table(table).map_err(|e| e.to_string())
310 }
311
312 #[deprecated(note = "Use validate_column() which returns ValidationError")]
314 pub fn validate_column_legacy(&self, table: &str, column: &str) -> Result<(), String> {
315 self.validate_column(table, column).map_err(|e| e.to_string())
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Misspelled table names are rejected with the nearest registered name.
    #[test]
    fn test_did_you_mean_table() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("orders", &["id", "total"]);

        assert!(v.validate_table("users").is_ok());

        let err = v.validate_table("usr").unwrap_err();
        assert!(
            matches!(err, ValidationError::TableNotFound { suggestion: Some(ref s), .. } if s == "users")
        );

        let err = v.validate_table("usrs").unwrap_err();
        assert!(
            matches!(err, ValidationError::TableNotFound { suggestion: Some(ref s), .. } if s == "users")
        );
    }

    /// Misspelled column names are rejected with the nearest registered name;
    /// `*` and columns of unknown tables are let through.
    #[test]
    fn test_did_you_mean_column() {
        let mut v = Validator::new();
        v.add_table("users", &["email", "password"]);

        assert!(v.validate_column("users", "email").is_ok());
        assert!(v.validate_column("users", "*").is_ok());
        // An unknown table is reported by validate_table, not here.
        assert!(v.validate_column("missing", "anything").is_ok());

        let err = v.validate_column("users", "emial").unwrap_err();
        assert!(
            matches!(err, ValidationError::ColumnNotFound { suggestion: Some(ref s), .. } if s == "email")
        );
    }

    /// Qualified (`table.column`) names are not resolved by `validate_column`.
    #[test]
    fn test_qualified_column_name() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("profiles", &["user_id", "avatar"]);

        assert!(v.validate_column("users", "users.id").is_ok());
        assert!(v.validate_column("users", "profiles.user_id").is_ok());
    }

    /// A whole command is accepted when valid and yields one error per bad
    /// column otherwise.
    #[test]
    fn test_validate_command() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "email", "name"]);

        let cmd = QailCmd::get("users").columns(["id", "email"]);
        assert!(v.validate_command(&cmd).is_ok());

        let cmd = QailCmd::get("users").columns(["id", "emial"]);
        let errors = v.validate_command(&cmd).unwrap_err();
        assert_eq!(errors.len(), 1);
        assert!(
            matches!(&errors[0], ValidationError::ColumnNotFound { column, .. } if column == "emial")
        );
    }

    /// Error messages render with and without a suggestion.
    #[test]
    fn test_error_display() {
        let err = ValidationError::TableNotFound {
            table: "usrs".to_string(),
            suggestion: Some("users".to_string()),
        };
        assert_eq!(err.to_string(), "Table 'usrs' not found. Did you mean 'users'?");

        let err = ValidationError::TableNotFound {
            table: "usrs".to_string(),
            suggestion: None,
        };
        assert_eq!(err.to_string(), "Table 'usrs' not found.");

        let err = ValidationError::ColumnNotFound {
            table: "users".to_string(),
            column: "emial".to_string(),
            suggestion: Some("email".to_string()),
        };
        assert_eq!(
            err.to_string(),
            "Column 'emial' not found in table 'users'. Did you mean 'email'?"
        );

        let err = ValidationError::ColumnNotFound {
            table: "users".to_string(),
            column: "emial".to_string(),
            suggestion: None,
        };
        assert_eq!(err.to_string(), "Column 'emial' not found in table 'users'.");
    }
}