1use std::collections::HashMap;
4use strsim::levenshtein;
5
/// Validates table and column names against a registered schema.
///
/// Unknown names are reported with a "Did you mean ...?" hint when a
/// sufficiently close registered name exists.
#[derive(Debug, Clone, Default)]
pub struct Validator {
    /// Registered table names, in registration order.
    tables: Vec<String>,
    /// Column names of each registered table, keyed by table name.
    columns: HashMap<String, Vec<String>>,
}
12
13impl Validator {
14 pub fn new() -> Self {
16 Self {
17 tables: Vec::new(),
18 columns: HashMap::new(),
19 }
20 }
21
22 pub fn add_table(&mut self, table: &str, cols: &[&str]) {
24 self.tables.push(table.to_string());
25 self.columns.insert(
26 table.to_string(),
27 cols.iter().map(|s| s.to_string()).collect(),
28 );
29 }
30
31 pub fn table_names(&self) -> &[String] {
33 &self.tables
34 }
35
36 pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
38 self.columns.get(table)
39 }
40
41 pub fn validate_table(&self, table: &str) -> Result<(), String> {
43 if self.tables.contains(&table.to_string()) {
44 Ok(())
45 } else {
46 let suggestions = self.did_you_mean(table, &self.tables);
47 if let Some(sugg) = suggestions {
48 Err(format!("Table '{}' not found. Did you mean '{}'?", table, sugg))
49 } else {
50 Err(format!("Table '{}' not found.", table))
51 }
52 }
53 }
54
55 pub fn validate_column(&self, table: &str, column: &str) -> Result<(), String> {
57 if !self.tables.contains(&table.to_string()) {
59 return Ok(());
60 }
61
62 if let Some(cols) = self.columns.get(table) {
63 if cols.contains(&column.to_string()) || column == "*" {
65 return Ok(());
66 }
67
68 let suggestions = self.did_you_mean(column, cols);
70 if let Some(sugg) = suggestions {
71 Err(format!(
72 "Column '{}' not found in table '{}'. Did you mean '{}'?",
73 column, table, sugg
74 ))
75 } else {
76 Err(format!("Column '{}' not found in table '{}'.", column, table))
77 }
78 } else {
79 Ok(())
80 }
81 }
82
83 fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
85 let mut best_match = None;
86 let mut min_dist = usize::MAX;
87
88 for cand in candidates {
89 let cand_str = cand.as_ref();
90 let dist = levenshtein(input, cand_str);
91
92 let threshold = match input.len() {
94 0..=2 => 0, 3..=5 => 2, _ => 3, };
98
99 if dist <= threshold && dist < min_dist {
100 min_dist = dist;
101 best_match = Some(cand_str.to_string());
102 }
103 }
104
105 best_match
106 }
107
108 pub fn validate_command(&self, cmd: &crate::ast::QailCmd) -> Result<(), Vec<String>> {
111 let mut errors = Vec::new();
112
113 if let Err(e) = self.validate_table(&cmd.table) {
115 errors.push(e);
116 }
117
118 for col in &cmd.columns {
120 if let crate::ast::Column::Named(name) = col {
121 if let Err(e) = self.validate_column(&cmd.table, name) {
122 errors.push(e);
123 }
124 }
125 }
126
127 for cage in &cmd.cages {
129 for cond in &cage.conditions {
130 if let Err(e) = self.validate_column(&cmd.table, &cond.column) {
131 errors.push(e);
132 }
133 }
134 }
135
136 for join in &cmd.joins {
138 if let Err(e) = self.validate_table(&join.table) {
139 errors.push(e);
140 }
141 }
142
143 if errors.is_empty() {
144 Ok(())
145 } else {
146 Err(errors)
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// A misspelled table name is rejected with a hint naming the closest table.
    #[test]
    fn test_did_you_mean_table() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("orders", &["id", "total"]);

        // An exact match passes.
        assert!(validator.validate_table("users").is_ok());

        // Each near miss should suggest the same registered table.
        for typo in &["usr", "usrs"] {
            let message = validator.validate_table(typo).unwrap_err();
            assert!(message.contains("Did you mean 'users'?"));
        }
    }

    /// A misspelled column name is rejected with a hint naming the closest column.
    #[test]
    fn test_did_you_mean_column() {
        let mut validator = Validator::new();
        validator.add_table("users", &["email", "password"]);

        // An exact match passes.
        assert!(validator.validate_column("users", "email").is_ok());

        let message = validator.validate_column("users", "emial").unwrap_err();
        assert!(message.contains("Did you mean 'email'?"));
    }
}