use std::collections::{HashMap, HashSet};
use anyhow::Result;
use crate::ast::rq::{
fold_table, fold_table_ref, CId, ColumnDecl, ColumnDefKind, IrFold, Query, TId, TableDecl,
TableRef, Transform, Window,
};
use crate::utils::IdGenerator;
/// Lookup tables and ID generators collected from a query.
///
/// Built by `AnchorContext::of` and extended through the `register_*` methods.
#[derive(Default)]
pub struct AnchorContext {
    // Declarations and where they live.
    /// Column declarations, keyed by column id.
    pub(super) columns_decls: HashMap<CId, ColumnDecl>,
    /// Table instance each column belongs to (only for columns with a known location).
    pub(super) columns_loc: HashMap<CId, TIId>,
    /// Table declarations, keyed by table id.
    pub(super) table_decls: HashMap<TId, TableDecl>,
    /// Table references, keyed by the instance id they were registered under.
    pub(super) table_instances: HashMap<TIId, TableRef>,

    // Id generators.
    pub(super) cid: IdGenerator<CId>,
    pub(super) tid: IdGenerator<TId>,
    pub(super) tiid: IdGenerator<TIId>,

    // Counters for generated names.
    /// Counter behind generated `_expr_N` column names.
    col_name: IdGenerator<usize>,
    /// Counter behind generated `table_N` table names.
    table_name: IdGenerator<usize>,
}
/// Identifier of a table instance: one particular use of a table within a query.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TIId(usize);

impl From<usize> for TIId {
    /// Wraps a raw counter value as a table-instance id.
    fn from(raw: usize) -> Self {
        Self(raw)
    }
}
impl AnchorContext {
pub fn of(query: Query) -> (Self, Query) {
let (cid, tid, query) = IdGenerator::load(query);
let context = AnchorContext {
cid,
tid,
tiid: IdGenerator::new(),
..Default::default()
};
QueryLoader::load(context, query)
}
pub fn register_wildcard(&mut self, tiid: TIId) -> CId {
let cd = self.register_column(ColumnDefKind::Wildcard, None, Some(tiid));
cd.id
}
pub fn register_column(
&mut self,
kind: ColumnDefKind,
window: Option<Window>,
tiid: Option<TIId>,
) -> ColumnDecl {
let decl = ColumnDecl {
id: self.cid.gen(),
kind,
window,
is_aggregation: false,
};
self.columns_decls.insert(decl.id, decl.clone());
if let Some(tiid) = tiid {
self.columns_loc.insert(decl.id, tiid);
}
decl
}
pub fn register_table_instance(&mut self, table_ref: TableRef) {
let tiid = self.tiid.gen();
for column in &table_ref.columns {
self.columns_decls.insert(column.id, column.clone());
self.columns_loc.insert(column.id, tiid);
}
self.table_instances.insert(tiid, table_ref);
}
pub fn get_column_name(&self, cid: &CId) -> Option<String> {
let decl = self.columns_decls.get(cid).unwrap();
decl.get_name().cloned()
}
pub fn gen_table_name(&mut self) -> String {
format!("table_{}", self.table_name.gen())
}
pub fn gen_column_name(&mut self) -> String {
format!("_expr_{}", self.col_name.gen())
}
pub fn ensure_column_name(&mut self, cid: &CId) -> String {
let decl = self.columns_decls.get_mut(cid).unwrap();
match &mut decl.kind {
ColumnDefKind::Expr { name, .. } => {
if name.is_none() {
*name = Some(format!("_expr_{}", self.col_name.gen()));
}
name.clone().unwrap()
}
ColumnDefKind::Wildcard => "*".to_string(),
ColumnDefKind::ExternRef(name) => name.clone(),
}
}
pub fn materialize_name(&mut self, cid: &CId) -> (Option<String>, String) {
let col_name = self.ensure_column_name(cid);
let table_name = self.columns_loc.get(cid).map(|tiid| {
let table = self.table_instances.get(tiid).unwrap();
if let Some(alias) = &table.name {
alias.clone()
} else {
let decl = &self.table_decls[&table.source];
decl.name.clone().unwrap()
}
});
(table_name, col_name)
}
pub fn determine_select_columns(&self, pipeline: &[Transform]) -> Vec<CId> {
let mut columns = Vec::new();
for transform in pipeline {
match transform {
Transform::From(table) => {
columns = table.columns.iter().map(|c| c.id).collect();
}
Transform::Select(cols) => columns = cols.clone(),
Transform::Aggregate { partition, compute } => {
columns = [partition.clone(), compute.clone()].concat()
}
Transform::Join { with: table, .. } => {
columns.extend(table.columns.iter().map(|c| c.id));
}
_ => {}
}
}
columns
}
pub fn collect_pipeline_inputs(&self, pipeline: &[Transform]) -> (Vec<TIId>, HashSet<CId>) {
let mut tables = Vec::new();
let mut columns = HashSet::new();
for t in pipeline {
if let Transform::From(table) | Transform::Join { with: table, .. } = t {
if let Some(column) = table.columns.first() {
tables.push(self.columns_loc[&column.id]);
} else {
panic!("table without columns?")
}
columns.extend(table.columns.iter().map(|c| c.id));
}
}
(tables, columns)
}
}
/// Fold pass that records the declarations and table instances it visits
/// into an `AnchorContext`.
struct QueryLoader {
    context: AnchorContext,
}

impl QueryLoader {
    /// Folds `query` with a loader that owns `context`.
    ///
    /// Returns the populated context together with the folded query.
    ///
    /// # Panics
    /// Panics if folding returns an error. `QueryLoader`'s own overrides never
    /// create errors; they only propagate errors from the default fold helpers.
    fn load(context: AnchorContext, query: Query) -> (AnchorContext, Query) {
        let mut loader = QueryLoader { context };
        let query = loader
            .fold_query(query)
            .expect("failed to load query into AnchorContext");
        (loader.context, query)
    }
}
impl IrFold for QueryLoader {
    /// Folds a table declaration, then stores the folded result in `table_decls`.
    fn fold_table(&mut self, table: TableDecl) -> Result<TableDecl> {
        let folded = fold_table(self, table)?;
        self.context.table_decls.insert(folded.id, folded.clone());
        Ok(folded)
    }

    /// Stores each visited column declaration in `columns_decls` and returns it unchanged.
    fn fold_column_decl(&mut self, cd: ColumnDecl) -> Result<ColumnDecl> {
        let ctx = &mut self.context;
        ctx.columns_decls.insert(cd.id, cd.clone());
        Ok(cd)
    }

    /// Assigns a fresh `TIId` to this table reference before folding it.
    ///
    /// Each of the reference's columns is recorded as located in the new
    /// instance, and a copy of the (unfolded) reference is stored in
    /// `table_instances`.
    fn fold_table_ref(&mut self, table_ref: TableRef) -> Result<TableRef> {
        let tiid = self.context.tiid.gen();
        let locations = table_ref.columns.iter().map(|col| (col.id, tiid));
        self.context.columns_loc.extend(locations);
        self.context.table_instances.insert(tiid, table_ref.clone());
        fold_table_ref(self, table_ref)
    }
}