use std::collections::HashMap;
use crate::error::{Result, SqlError};
use crate::parser::normalize::table_name_from_factor;
use crate::types::{CollectionInfo, EngineType, SqlCatalog};
/// A table referenced by a query, together with its catalog metadata.
#[derive(Debug, Clone)]
pub struct ResolvedTable {
    /// Table (collection) name, as passed to the catalog lookup.
    pub name: String,
    /// Alias given to the table in the query, if any.
    pub alias: Option<String>,
    /// Catalog metadata for the table.
    pub info: CollectionInfo,
}

impl ResolvedTable {
    /// Name under which the query refers to this table: the alias when one
    /// was supplied, otherwise the table name itself.
    pub fn ref_name(&self) -> &str {
        match &self.alias {
            Some(alias) => alias.as_str(),
            None => self.name.as_str(),
        }
    }
}
/// The tables visible to a query, keyed by the name the query uses for each.
#[derive(Default, Debug)]
pub struct TableScope {
    /// Resolved tables keyed by reference name (alias, or table name when unaliased).
    pub tables: HashMap<String, ResolvedTable>,
    /// Reference names in insertion order. `HashMap` iteration order is
    /// unspecified, so this keeps lookups deterministic.
    order: Vec<String>,
}
impl TableScope {
    /// Creates an empty scope.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `table` under its reference name (alias if present,
    /// otherwise table name).
    ///
    /// The key is lowercased. [`TableScope::resolve_column`] lowercases the
    /// qualifier before looking it up, so a mixed-case key would otherwise
    /// never be found.
    ///
    /// # Errors
    ///
    /// Returns [`SqlError::Parse`] if a table is already registered under the
    /// same (case-insensitive) reference name.
    pub fn add(&mut self, table: ResolvedTable) -> Result<()> {
        let key = table.ref_name().to_lowercase();
        if self.tables.contains_key(&key) {
            return Err(SqlError::Parse {
                detail: format!("duplicate table reference: {key}"),
            });
        }
        self.order.push(key.clone());
        self.tables.insert(key, table);
        Ok(())
    }

    /// Resolves a column reference to `(table_name, lowercased_column)`.
    ///
    /// - With a qualifier (`t.col`), the qualifier is lowercased and looked up
    ///   among the registered reference names. The column is then validated
    ///   against that table.
    /// - Without a qualifier, every table in the scope is checked in
    ///   insertion order, and exactly one of them must contain the column.
    ///
    /// The returned table name is always the underlying table name, never
    /// the alias.
    ///
    /// # Errors
    ///
    /// - [`SqlError::UnknownTable`] if the qualifier names no table in scope.
    /// - [`SqlError::UnknownColumn`] if no candidate table has the column.
    /// - [`SqlError::AmbiguousColumn`] if an unqualified column exists in
    ///   more than one table.
    pub fn resolve_column(
        &self,
        table_ref: Option<&str>,
        column: &str,
    ) -> Result<(String, String)> {
        let col = column.to_lowercase();
        if let Some(tref) = table_ref {
            let tref_lower = tref.to_lowercase();
            let table = self
                .tables
                .get(&tref_lower)
                .ok_or_else(|| SqlError::UnknownTable {
                    name: tref_lower.clone(),
                })?;
            self.validate_column(table, &col)?;
            return Ok((table.name.clone(), col));
        }

        // `get` rather than indexing: `tables` is `pub`, so it may have been
        // modified without updating `order`, and indexing would then panic.
        let matches: Vec<&ResolvedTable> = self
            .order
            .iter()
            .filter_map(|key| self.tables.get(key))
            .filter(|table| self.column_exists(table, &col))
            .collect();

        match matches.as_slice() {
            [table] => Ok((table.name.clone(), col)),
            [] => {
                // `column_exists` accepts any column for a schemaless table,
                // so such a table reached through `order` always matches
                // above. This fallback therefore only fires when a lone
                // schemaless table is in `tables` but not in `order`.
                if let Some(table) = self.single_table() {
                    if table.info.engine == EngineType::DocumentSchemaless {
                        return Ok((table.name.clone(), col));
                    }
                }
                // Report the underlying table name (not the alias), as
                // `validate_column` does for qualified references.
                let table = self
                    .order
                    .first()
                    .and_then(|key| self.tables.get(key))
                    .map(|t| t.name.clone())
                    .unwrap_or_else(|| "<unknown>".into());
                Err(SqlError::UnknownColumn { table, column: col })
            }
            _ => Err(SqlError::AmbiguousColumn { column: col }),
        }
    }

    /// Returns true if `table` has a column named exactly `column`.
    ///
    /// Schemaless document tables accept any column name.
    /// NOTE(review): comparison is case-sensitive against catalog column
    /// names, while callers pass lowercased names. This assumes the catalog
    /// stores lowercase column names; confirm.
    fn column_exists(&self, table: &ResolvedTable, column: &str) -> bool {
        table.info.engine == EngineType::DocumentSchemaless
            || table.info.columns.iter().any(|c| c.name == column)
    }

    /// Returns [`SqlError::UnknownColumn`] (naming the underlying table) if
    /// `table` has no column `column`.
    fn validate_column(&self, table: &ResolvedTable, column: &str) -> Result<()> {
        if self.column_exists(table, column) {
            Ok(())
        } else {
            Err(SqlError::UnknownColumn {
                table: table.name.clone(),
                column: column.into(),
            })
        }
    }

    /// Returns the only table in scope, or `None` if there are zero or
    /// several tables.
    pub fn single_table(&self) -> Option<&ResolvedTable> {
        if self.tables.len() == 1 {
            self.tables.values().next()
        } else {
            None
        }
    }

    /// Builds a scope from a `FROM` clause.
    ///
    /// For each item in `from`, the base relation is resolved first, then each
    /// joined relation, in source order. Each one is looked up in `catalog`
    /// and added to the scope.
    ///
    /// # Errors
    ///
    /// Propagates catalog errors. Also returns [`SqlError::UnknownTable`] for
    /// tables the catalog does not know, and [`SqlError::Parse`] for duplicate
    /// reference names.
    pub fn resolve_from(
        catalog: &dyn SqlCatalog,
        from: &[sqlparser::ast::TableWithJoins],
    ) -> Result<Self> {
        let mut scope = Self::new();
        for table_with_joins in from {
            scope.resolve_table_factor(catalog, &table_with_joins.relation)?;
            for join in &table_with_joins.joins {
                scope.resolve_table_factor(catalog, &join.relation)?;
            }
        }
        Ok(scope)
    }

    /// Resolves a single table factor against `catalog` and adds it to the scope.
    ///
    /// Factors for which `table_name_from_factor` yields `None` are skipped
    /// silently. NOTE(review): presumably these are derived tables or table
    /// functions; confirm against `table_name_from_factor`.
    fn resolve_table_factor(
        &mut self,
        catalog: &dyn SqlCatalog,
        factor: &sqlparser::ast::TableFactor,
    ) -> Result<()> {
        if let Some((name, alias)) = table_name_from_factor(factor) {
            let info = catalog
                .get_collection(&name)?
                .ok_or_else(|| SqlError::UnknownTable { name: name.clone() })?;
            self.add(ResolvedTable { name, alias, info })?;
        }
        Ok(())
    }
}