1use featherdb_catalog::Table;
4use featherdb_core::{find_best_match, levenshtein_distance, Error, Result};
5use std::sync::Arc;
6
/// Name-resolution scope used while planning a query.
///
/// Holds the tables visible in the current scope (each under the name it is
/// referenced by), the SQL text supplied via `with_sql`, and the column names
/// visible from an enclosing query (used by `is_outer_column`).
#[derive(Debug, Clone, Default)]
pub struct PlannerContext {
    /// Tables in scope, in insertion order, each paired with the name it is
    /// referenced under: the alias when one was given, otherwise the table's
    /// own name (see `add_table`).
    tables: Vec<(String, Arc<Table>)>,
    /// SQL text supplied via `with_sql`.
    /// NOTE(review): not read by any method visible here — presumably kept for
    /// error reporting; confirm against other uses.
    sql: String,
    /// Column names visible from an enclosing query, consulted by
    /// `is_outer_column`; entries may be bare (`id`) or qualified (`alias.id`).
    outer_columns: Vec<String>,
}
19
20impl PlannerContext {
21 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn with_sql(mut self, sql: impl Into<String>) -> Self {
28 self.sql = sql.into();
29 self
30 }
31
32 pub fn add_table(&mut self, alias: Option<String>, table: Arc<Table>) {
34 let name = alias.unwrap_or_else(|| table.name.clone());
35 self.tables.push((name, table));
36 }
37
38 pub fn with_outer_columns(mut self, columns: Vec<String>) -> Self {
40 self.outer_columns = columns;
41 self
42 }
43
44 pub fn current_columns(&self) -> Vec<String> {
46 self.available_columns()
47 }
48
49 pub fn is_outer_column(&self, table: Option<&str>, column: &str) -> bool {
52 if self.find_column(table, column).is_ok() {
54 return false; }
56
57 if !self.outer_columns.is_empty() {
59 let col_name = if let Some(t) = table {
60 format!("{}.{}", t, column)
61 } else {
62 column.to_string()
63 };
64
65 self.outer_columns
67 .iter()
68 .any(|c| c == &col_name || c.ends_with(&format!(".{}", column)))
69 } else {
70 false
71 }
72 }
73
74 pub fn find_column(&self, table: Option<&str>, column: &str) -> Result<(Arc<Table>, usize)> {
76 if let Some(table_name) = table {
77 for (alias, tbl) in &self.tables {
79 if alias.eq_ignore_ascii_case(table_name)
80 || tbl.name.eq_ignore_ascii_case(table_name)
81 {
82 return match tbl.get_column_index(column) {
83 Some(idx) => Ok((tbl.clone(), idx)),
84 None => Err(self.column_not_found_error(column, &tbl.name)),
85 };
86 }
87 }
88 Err(self.table_not_found_error(table_name))
89 } else {
90 let mut found = Vec::new();
92 for (alias, tbl) in &self.tables {
93 if let Some(idx) = tbl.get_column_index(column) {
94 found.push((alias.clone(), tbl.clone(), idx));
95 }
96 }
97
98 match found.len() {
99 0 => Err(self.column_not_found_in_all_tables_error(column)),
100 1 => Ok((found[0].1.clone(), found[0].2)),
101 _ => {
102 let tables: Vec<_> = found.iter().map(|(a, _, _)| a.clone()).collect();
103 Err(Error::AmbiguousColumn {
104 column: column.to_string(),
105 tables: tables.join(", "),
106 })
107 }
108 }
109 }
110 }
111
112 fn column_not_found_error(&self, column: &str, table_name: &str) -> Error {
114 let suggestion = self
116 .tables
117 .iter()
118 .find(|(_, t)| t.name == table_name)
119 .and_then(|(_, t)| {
120 let columns: Vec<&str> = t.columns.iter().map(|c| c.name.as_str()).collect();
121 find_best_match(column, &columns, 3)
122 })
123 .map(|s| s.to_string());
124
125 Error::ColumnNotFound {
126 column: column.to_string(),
127 table: table_name.to_string(),
128 suggestion,
129 }
130 }
131
132 fn column_not_found_in_all_tables_error(&self, column: &str) -> Error {
134 let mut best_match: Option<(String, &str)> = None;
136 let mut best_distance = usize::MAX;
137
138 for (alias, table) in &self.tables {
139 for col in &table.columns {
140 let distance = levenshtein_distance(column, &col.name);
141 if distance < best_distance && distance <= 3 {
142 best_distance = distance;
143 best_match = Some((alias.clone(), col.name.as_str()));
144 }
145 }
146 }
147
148 let table_name = if self.tables.len() == 1 {
149 self.tables[0].0.clone()
150 } else {
151 "query".to_string()
152 };
153
154 let suggestion = best_match.map(|(tbl, col)| {
155 if self.tables.len() == 1 {
156 col.to_string()
157 } else {
158 format!("{}.{}", tbl, col)
159 }
160 });
161
162 Error::ColumnNotFound {
163 column: column.to_string(),
164 table: table_name,
165 suggestion,
166 }
167 }
168
169 fn table_not_found_error(&self, table: &str) -> Error {
171 let available: Vec<&str> = self.tables.iter().map(|(a, _)| a.as_str()).collect();
172 let suggestion = find_best_match(table, &available, 3).map(|s| s.to_string());
173
174 Error::TableNotFound {
175 table: table.to_string(),
176 suggestion,
177 }
178 }
179
180 pub fn available_columns(&self) -> Vec<String> {
182 let mut columns = Vec::new();
183 for (alias, table) in &self.tables {
184 for col in &table.columns {
185 if self.tables.len() == 1 {
186 columns.push(col.name.clone());
187 } else {
188 columns.push(format!("{}.{}", alias, col.name));
189 }
190 }
191 }
192 columns
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use featherdb_catalog::TableBuilder;

    /// Builds a table named `name` whose columns are all `Integer`, declared
    /// in the order given.
    fn create_test_table(name: &str, columns: Vec<&str>) -> Table {
        columns
            .into_iter()
            .fold(TableBuilder::new(name), |builder, col| {
                builder.column(col, featherdb_core::ColumnType::Integer)
            })
            .build(0, featherdb_core::PageId::INVALID)
    }

    #[test]
    fn test_planner_context_column_lookup() {
        let mut ctx = PlannerContext::new();
        ctx.add_table(
            None,
            Arc::new(create_test_table("users", vec!["id", "name", "age"])),
        );
        ctx.add_table(
            None,
            Arc::new(create_test_table("orders", vec!["id", "user_id", "total"])),
        );

        // A qualified reference resolves against the named table only.
        let (owner, position) = ctx.find_column(Some("users"), "name").unwrap();
        assert_eq!(owner.name, "users");
        assert_eq!(position, 1);

        // `id` lives in both tables, so an unqualified reference is ambiguous.
        assert!(matches!(
            ctx.find_column(None, "id"),
            Err(Error::AmbiguousColumn { .. })
        ));

        // `age` lives only in `users`, so no qualifier is needed.
        let (owner, position) = ctx.find_column(None, "age").unwrap();
        assert_eq!(owner.name, "users");
        assert_eq!(position, 2);
    }

    #[test]
    fn test_planner_context_available_columns() {
        let mut ctx = PlannerContext::new();
        ctx.add_table(None, Arc::new(create_test_table("users", vec!["id", "name"])));

        // A lone table yields bare, unqualified column names.
        assert_eq!(ctx.available_columns(), vec!["id", "name"]);
    }
}