use std::any::Any;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use anyhow::{bail, Context, Result};
use itertools::Itertools;
use tracing::trace;
use super::optimizer::{ExprId, GroupId, PredId};
use crate::cost::{Cost, Statistics};
use crate::nodes::{ArcPlanNode, ArcPredNode, NodeType, PlanNode, PlanNodeOrGroup};
use crate::property::PropertyBuilderAny;
/// Shared, reference-counted handle to a [`MemoPlanNode`].
pub type ArcMemoPlanNode<T> = Arc<MemoPlanNode<T>>;
/// A plan node in the form the memo stores it: children are referenced by
/// [`GroupId`] and predicates by [`PredId`] rather than being embedded.
///
/// The type is hashable and comparable, which lets it serve as a map key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MemoPlanNode<T: NodeType> {
/// Type tag of the plan node.
pub typ: T,
/// Groups this node takes as inputs.
pub children: Vec<GroupId>,
/// Ids of the predicates attached to this node.
pub predicates: Vec<PredId>,
}
/// Renders the node as `(<typ> <child>... <pred>...)`, each element separated
/// by a single space.
impl<T: NodeType> std::fmt::Display for MemoPlanNode<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "({}", self.typ)?;
        // Children are printed before predicates; both stop at the first
        // formatting error.
        self.children
            .iter()
            .try_for_each(|group| write!(f, " {}", group))?;
        self.predicates
            .iter()
            .try_for_each(|pred| write!(f, " {}", pred))?;
        f.write_str(")")
    }
}
/// Details recorded for the winning expression of a group.
#[derive(Clone)]
pub struct WinnerInfo {
/// Expression that won; `get_best_group_binding_inner` expands this
/// expression when it builds the best plan.
pub expr_id: ExprId,
/// Weighted cost as a scalar. `NaiveMemo::update_group_info` asserts that it
/// is non-zero. NOTE(review): presumably covers the whole subtree — confirm
/// against the cost model.
pub total_weighted_cost: f64,
/// NOTE(review): presumably the weighted cost of this operator alone,
/// excluding children — confirm against the cost model.
pub operation_weighted_cost: f64,
/// NOTE(review): presumably the unweighted cost of the whole subtree — confirm.
pub total_cost: Cost,
/// NOTE(review): presumably the unweighted cost of this operator alone — confirm.
pub operation_cost: Cost,
/// Statistics associated with the winner, shared via `Arc`.
pub statistics: Arc<Statistics>,
}
/// Outcome of searching for the best expression of a group.
#[derive(Clone)]
pub enum Winner {
/// No verdict yet; this is the `Default` value.
Unknown,
/// A verdict was reached but there is no winner (`has_decided` is true,
/// `has_full_winner` is false).
Impossible,
/// A winning expression was found, together with its details.
Full(WinnerInfo),
}
impl Winner {
pub fn has_full_winner(&self) -> bool {
matches!(self, Self::Full { .. })
}
pub fn has_decided(&self) -> bool {
matches!(self, Self::Full { .. } | Self::Impossible)
}
pub fn as_full_winner(&self) -> Option<&WinnerInfo> {
match self {
Self::Full(info) => Some(info),
_ => None,
}
}
}
impl Default for Winner {
fn default() -> Self {
Self::Unknown
}
}
/// Mutable bookkeeping attached to a group; at present only the winner.
#[derive(Default, Clone)]
pub struct GroupInfo {
/// Current search verdict for the group.
pub winner: Winner,
}
/// A group in the memo: a set of expressions plus shared info and properties.
pub struct Group {
/// Ids of the expressions that belong to this group.
pub(crate) group_exprs: HashSet<ExprId>,
/// Winner bookkeeping for the group.
pub(crate) info: GroupInfo,
/// Type-erased properties of the group; `NaiveMemo` stores one entry per
/// property builder, in builder order.
pub(crate) properties: Arc<[Box<dyn Any + Send + Sync + 'static>]>,
}
/// Interface of a memo table that stores plan expressions organised into
/// groups, together with interned predicates.
pub trait Memo<T: NodeType>: 'static + Send + Sync {
    /// Adds `rel_node` without naming a target group and returns the group
    /// and expression ids it resolves to.
    fn add_new_expr(&mut self, rel_node: ArcPlanNode<T>) -> (GroupId, ExprId);

    /// Adds `rel_node` to the group `group_id`.
    ///
    /// NOTE(review): in `NaiveMemo`, a plan node yields `Some(expr_id)` while
    /// a group reference merges the two groups and yields `None`.
    fn add_expr_to_group(
        &mut self,
        rel_node: PlanNodeOrGroup<T>,
        group_id: GroupId,
    ) -> Option<ExprId>;

    /// Registers a predicate and returns its id.
    fn add_new_pred(&mut self, pred_node: ArcPredNode<T>) -> PredId;

    /// Returns the group that contains `expr_id`.
    fn get_group_id(&self, expr_id: ExprId) -> GroupId;

    /// Returns the memoized node stored for `expr_id`.
    fn get_expr_memoed(&self, expr_id: ExprId) -> ArcMemoPlanNode<T>;

    /// Returns the ids of all groups in the memo.
    fn get_all_group_ids(&self) -> Vec<GroupId>;

    /// Returns the group stored under `group_id`.
    fn get_group(&self, group_id: GroupId) -> &Group;

    /// Returns the predicate node registered under `pred_id`.
    fn get_pred(&self, pred_id: PredId) -> ArcPredNode<T>;

    /// Replaces the info (winner bookkeeping) of `group_id`.
    fn update_group_info(&mut self, group_id: GroupId, group_info: GroupInfo);

    /// Returns an estimate of the size of the plan space held by the memo.
    fn estimated_plan_space(&self) -> usize;

    /// Returns the expression ids of `group_id` in ascending order.
    fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec<ExprId> {
        let mut expr_ids = self
            .get_group(group_id)
            .group_exprs
            .iter()
            .copied()
            .collect_vec();
        // The backing set is unordered; sort for a deterministic result.
        expr_ids.sort();
        expr_ids
    }

    /// Returns the info of `group_id`.
    fn get_group_info(&self, group_id: GroupId) -> &GroupInfo {
        let group = self.get_group(group_id);
        &group.info
    }

    /// Builds the plan made of the winning expression of `group_id` and,
    /// recursively, of its child groups.
    ///
    /// `post_process` is invoked for every node that gets built, with the
    /// node, its group id and the winner details.
    ///
    /// # Errors
    ///
    /// Fails when a visited group has no full winner.
    fn get_best_group_binding(
        &self,
        group_id: GroupId,
        mut post_process: impl FnMut(ArcPlanNode<T>, GroupId, &WinnerInfo),
    ) -> Result<ArcPlanNode<T>> {
        get_best_group_binding_inner(self, group_id, &mut post_process)
    }

    /// Returns the single plan rooted at `group_id`, or `None` when a visited
    /// group is empty.
    ///
    /// # Panics
    ///
    /// Panics when a visited group holds more than one expression.
    fn get_predicate_binding(&self, group_id: GroupId) -> Option<ArcPlanNode<T>> {
        get_predicate_binding_group_inner(self, group_id, true)
    }

    /// Like [`Memo::get_predicate_binding`], but returns `None` instead of
    /// panicking when a visited group holds more than one expression.
    fn try_get_predicate_binding(&self, group_id: GroupId) -> Option<ArcPlanNode<T>> {
        get_predicate_binding_group_inner(self, group_id, false)
    }
}
/// Recursively materializes the winning plan of `group_id`.
///
/// The winner expression of the group is expanded, each child group is
/// resolved the same way, and `post_process` is called on every node after
/// its children have been built.
///
/// # Errors
///
/// Returns an error when the group (or any group below it) has no full
/// winner; errors from children carry the id of the parent expression as
/// context.
fn get_best_group_binding_inner<M: Memo<T> + ?Sized, T: NodeType>(
    this: &M,
    group_id: GroupId,
    post_process: &mut impl FnMut(ArcPlanNode<T>, GroupId, &WinnerInfo),
) -> Result<ArcPlanNode<T>> {
    // Guard clause: without a full winner there is nothing to bind.
    let Winner::Full(winner) = &this.get_group_info(group_id).winner else {
        bail!("no best group binding for group {}", group_id);
    };
    let winner_expr_id = winner.expr_id;
    let memo_expr = this.get_expr_memoed(winner_expr_id);

    let mut children = Vec::with_capacity(memo_expr.children.len());
    for &child_group in &memo_expr.children {
        let child_plan = get_best_group_binding_inner(this, child_group, post_process)
            .with_context(|| format!("when processing expr {}", winner_expr_id))?;
        children.push(PlanNodeOrGroup::PlanNode(child_plan));
    }

    let predicates = memo_expr
        .predicates
        .iter()
        .map(|&pred_id| this.get_pred(pred_id))
        .collect();
    let node = Arc::new(PlanNode {
        typ: memo_expr.typ.clone(),
        children,
        predicates,
    });
    post_process(node.clone(), group_id, winner);
    Ok(node)
}
/// Expands expression `expr_id` into a plan node whose children are the
/// predicate bindings of its child groups.
///
/// Returns `None` as soon as one child group has no binding.
/// `panic_on_invalid_group` is forwarded to the per-group lookup.
fn get_predicate_binding_expr_inner<M: Memo<T> + ?Sized, T: NodeType>(
    this: &M,
    expr_id: ExprId,
    panic_on_invalid_group: bool,
) -> Option<ArcPlanNode<T>> {
    let memo_expr = this.get_expr_memoed(expr_id);

    let mut children = Vec::with_capacity(memo_expr.children.len());
    for &child_group in &memo_expr.children {
        // `?` bails out with `None` when the child group cannot be bound.
        let child_plan =
            get_predicate_binding_group_inner(this, child_group, panic_on_invalid_group)?;
        children.push(PlanNodeOrGroup::PlanNode(child_plan));
    }

    let predicates = memo_expr
        .predicates
        .iter()
        .map(|&pred_id| this.get_pred(pred_id))
        .collect();
    Some(Arc::new(PlanNode {
        typ: memo_expr.typ.clone(),
        children,
        predicates,
    }))
}
/// Resolves group `group_id` to a plan, which is only possible when the group
/// holds exactly one expression.
///
/// An empty group yields `None`. A group with several expressions panics when
/// `panic_on_invalid_group` is set and yields `None` otherwise.
fn get_predicate_binding_group_inner<M: Memo<T> + ?Sized, T: NodeType>(
    this: &M,
    group_id: GroupId,
    panic_on_invalid_group: bool,
) -> Option<ArcPlanNode<T>> {
    let exprs = this.get_all_exprs_in_group(group_id);
    let len = exprs.len();
    match exprs.as_slice() {
        [] => None,
        [only_expr] => {
            get_predicate_binding_expr_inner(this, *only_expr, panic_on_invalid_group)
        }
        _ if panic_on_invalid_group => panic!("group {group_id} has {len} expressions"),
        _ => None,
    }
}
/// Hash-map based memo table.
pub struct NaiveMemo<T: NodeType> {
/// Live groups by id; a group that is merged into another one is removed.
groups: HashMap<GroupId, Group>,
/// Expression id -> memoized node.
expr_id_to_expr_node: HashMap<ExprId, ArcMemoPlanNode<T>>,
/// Predicate id -> predicate node.
pred_id_to_pred_node: HashMap<PredId, ArcPredNode<T>>,
/// Reverse index of `pred_id_to_pred_node`, used to deduplicate predicates.
pred_node_to_pred_id: HashMap<ArcPredNode<T>, PredId>,
/// Single counter from which group, expression and predicate ids are drawn.
group_expr_counter: usize,
/// Builders used to derive the properties of a newly created group.
property_builders: Arc<[Box<dyn PropertyBuilderAny<T>>]>,
/// Reverse index of `expr_id_to_expr_node`, used to deduplicate expressions.
expr_node_to_expr_id: HashMap<MemoPlanNode<T>, ExprId>,
/// Expression id -> id of the group that currently contains it.
expr_id_to_group_id: HashMap<ExprId, GroupId>,
/// Group id -> id of the live group it resolves to (identity for live
/// groups, the merge target for groups that were merged away).
merged_group_mapping: HashMap<GroupId, GroupId>,
/// Id of an expression dropped as a duplicate -> id of the expression that
/// replaced it.
dup_expr_mapping: HashMap<ExprId, ExprId>,
}
/// [`Memo`] implementation on top of the hash maps held by [`NaiveMemo`].
impl<T: NodeType> Memo<T> for NaiveMemo<T> {
    /// Adds `rel_node` (and, recursively, its plan-node children) to the memo
    /// and returns the group and expression ids it resolves to. An expression
    /// that is already known is not duplicated.
    fn add_new_expr(&mut self, rel_node: ArcPlanNode<T>) -> (GroupId, ExprId) {
        let (group_id, expr_id) = self
            .add_new_group_expr_inner(rel_node, None)
            .expect("should not trigger merge group");
        self.verify_integrity();
        (group_id, expr_id)
    }

    /// Adds `rel_node` to `group_id`.
    ///
    /// * A group reference merges the two groups and returns `None`.
    /// * A plan node is inserted into the (resolved) group and its expression
    ///   id is returned; if the expression already lives in another group the
    ///   two groups are merged.
    fn add_expr_to_group(
        &mut self,
        rel_node: PlanNodeOrGroup<T>,
        group_id: GroupId,
    ) -> Option<ExprId> {
        match rel_node {
            PlanNodeOrGroup::Group(input_group) => {
                // Resolve both ids first: either may have been merged before.
                let input_group = self.reduce_group(input_group);
                let group_id = self.reduce_group(group_id);
                self.merge_group_inner(input_group, group_id);
                None
            }
            PlanNodeOrGroup::PlanNode(rel_node) => {
                let reduced_group_id = self.reduce_group(group_id);
                let (returned_group_id, expr_id) = self
                    .add_new_group_expr_inner(rel_node, Some(reduced_group_id))
                    .unwrap();
                assert_eq!(returned_group_id, reduced_group_id);
                self.verify_integrity();
                Some(expr_id)
            }
        }
    }

    /// Interns `pred_node` and returns its id; an equal predicate that was
    /// registered before keeps its existing id.
    fn add_new_pred(&mut self, pred_node: ArcPredNode<T>) -> PredId {
        // Look up first so that a duplicate predicate does not consume an id
        // from the counter shared with groups and expressions.
        if let Some(&existing) = self.pred_node_to_pred_id.get(&pred_node) {
            return existing;
        }
        let pred_id = self.next_pred_id();
        self.pred_node_to_pred_id.insert(pred_node.clone(), pred_id);
        self.pred_id_to_pred_node.insert(pred_id, pred_node);
        pred_id
    }

    /// Returns the predicate registered under `pred_id`.
    ///
    /// # Panics
    ///
    /// Panics when `pred_id` is unknown.
    fn get_pred(&self, pred_id: PredId) -> ArcPredNode<T> {
        self.pred_id_to_pred_node[&pred_id].clone()
    }

    /// Returns the group of `expr_id`, following duplicate-expression
    /// redirections first.
    ///
    /// # Panics
    ///
    /// Panics when the resolved expression is unknown.
    fn get_group_id(&self, mut expr_id: ExprId) -> GroupId {
        while let Some(new_expr_id) = self.dup_expr_mapping.get(&expr_id) {
            expr_id = *new_expr_id;
        }
        *self
            .expr_id_to_group_id
            .get(&expr_id)
            .expect("expr not found in group mapping")
    }

    /// Returns the memoized node of `expr_id`, following duplicate-expression
    /// redirections first.
    ///
    /// # Panics
    ///
    /// Panics when the resolved expression is unknown.
    fn get_expr_memoed(&self, mut expr_id: ExprId) -> ArcMemoPlanNode<T> {
        while let Some(new_expr_id) = self.dup_expr_mapping.get(&expr_id) {
            expr_id = *new_expr_id;
        }
        self.expr_id_to_expr_node
            .get(&expr_id)
            .expect("expr not found in expr mapping")
            .clone()
    }

    /// Returns the ids of all live groups in ascending order.
    fn get_all_group_ids(&self) -> Vec<GroupId> {
        let mut ids = self.groups.keys().copied().collect_vec();
        ids.sort();
        ids
    }

    /// Returns the live group that `group_id` resolves to.
    ///
    /// # Panics
    ///
    /// Panics when `group_id` was never created.
    fn get_group(&self, group_id: GroupId) -> &Group {
        let group_id = self.reduce_group(group_id);
        self.groups
            .get(&group_id)
            .expect("group not found in memo")
    }

    /// Replaces the info of the live group that `group_id` resolves to.
    ///
    /// # Panics
    ///
    /// Panics when the new winner has a total weighted cost of exactly zero,
    /// or when `group_id` was never created.
    fn update_group_info(&mut self, group_id: GroupId, group_info: GroupInfo) {
        if let Winner::Full(WinnerInfo {
            total_weighted_cost,
            expr_id,
            ..
        }) = &group_info.winner
        {
            // Resolve the expression through `get_expr_memoed` so that a
            // deduplicated id still produces the intended assertion message.
            assert!(
                *total_weighted_cost != 0.0,
                "{}",
                self.get_expr_memoed(*expr_id)
            );
        }
        // Resolve the id like `get_group` does: the caller may still hold the
        // id of a group that has been merged into another one.
        let group_id = self.reduce_group(group_id);
        self.groups
            .get_mut(&group_id)
            .expect("group not found in memo")
            .info = group_info;
    }

    /// Reports the number of expressions currently stored.
    fn estimated_plan_space(&self) -> usize {
        self.expr_id_to_expr_node.len()
    }
}
impl<T: NodeType> NaiveMemo<T> {
    /// Creates an empty memo that derives group properties with
    /// `property_builders`.
    pub fn new(property_builders: Arc<[Box<dyn PropertyBuilderAny<T>>]>) -> Self {
        Self {
            expr_id_to_group_id: HashMap::new(),
            expr_id_to_expr_node: HashMap::new(),
            expr_node_to_expr_id: HashMap::new(),
            pred_id_to_pred_node: HashMap::new(),
            pred_node_to_pred_id: HashMap::new(),
            groups: HashMap::new(),
            group_expr_counter: 0,
            merged_group_mapping: HashMap::new(),
            property_builders,
            dup_expr_mapping: HashMap::new(),
        }
    }

    /// Allocates a fresh group id from the shared counter.
    fn next_group_id(&mut self) -> GroupId {
        let id = self.group_expr_counter;
        self.group_expr_counter += 1;
        GroupId(id)
    }

    /// Allocates a fresh expression id from the shared counter.
    fn next_expr_id(&mut self) -> ExprId {
        let id = self.group_expr_counter;
        self.group_expr_counter += 1;
        ExprId(id)
    }

    /// Allocates a fresh predicate id from the shared counter.
    fn next_pred_id(&mut self) -> PredId {
        let id = self.group_expr_counter;
        self.group_expr_counter += 1;
        PredId(id)
    }

    /// Checks that the indexes of the memo agree with each other.
    ///
    /// The checks only run in builds with debug assertions enabled and panic
    /// on the first inconsistency.
    fn verify_integrity(&self) {
        if cfg!(debug_assertions) {
            // The three expression indexes must describe the same set.
            let num_of_exprs = self.expr_id_to_expr_node.len();
            assert_eq!(num_of_exprs, self.expr_node_to_expr_id.len());
            assert_eq!(num_of_exprs, self.expr_id_to_group_id.len());

            // Every merge target must resolve to itself, and the set of
            // targets must be exactly the set of live groups.
            let mut valid_groups = HashSet::new();
            for to in self.merged_group_mapping.values() {
                assert_eq!(self.merged_group_mapping[to], *to);
                valid_groups.insert(*to);
            }
            assert_eq!(valid_groups.len(), self.groups.len());

            // Expressions must round-trip through the reverse index and may
            // only reference live groups.
            for (id, node) in self.expr_id_to_expr_node.iter() {
                assert_eq!(self.expr_node_to_expr_id[node], *id);
                for child in &node.children {
                    assert!(
                        valid_groups.contains(child),
                        "invalid group used in expression {}, where {} does not exist any more",
                        node,
                        child
                    );
                }
            }

            // Groups must be non-empty and partition the expressions.
            let mut cnt = 0;
            for (group_id, group) in &self.groups {
                assert!(valid_groups.contains(group_id));
                cnt += group.group_exprs.len();
                assert!(!group.group_exprs.is_empty());
                for expr in &group.group_exprs {
                    assert_eq!(self.expr_id_to_group_id[expr], *group_id);
                }
            }
            assert_eq!(cnt, num_of_exprs);
        }
    }

    /// Resolves `group_id` to the live group it currently stands for.
    ///
    /// # Panics
    ///
    /// Panics when `group_id` was never created.
    fn reduce_group(&self, group_id: GroupId) -> GroupId {
        self.merged_group_mapping[&group_id]
    }

    /// Merges group `merge_from` into group `merge_into` and fixes up every
    /// index that mentions `merge_from`.
    ///
    /// Both ids must refer to live groups. Rewriting child references can
    /// make two expressions identical; the duplicate is dropped, recorded in
    /// `dup_expr_mapping`, and the groups of the two expressions are merged
    /// recursively afterwards.
    fn merge_group_inner(&mut self, merge_into: GroupId, merge_from: GroupId) {
        if merge_into == merge_from {
            return;
        }
        trace!(event = "merge_group", merge_into = %merge_into, merge_from = %merge_from);

        // Move every expression of the source group into the target group.
        let group_merge_from = self.groups.remove(&merge_from).unwrap();
        let group_merge_into = self.groups.get_mut(&merge_into).unwrap();
        for from_expr in group_merge_from.group_exprs {
            let ret = self.expr_id_to_group_id.insert(from_expr, merge_into);
            assert!(ret.is_some());
            group_merge_into.group_exprs.insert(from_expr);
        }

        // Redirect the source group, and everything that already pointed at
        // it, to the target group.
        self.merged_group_mapping.insert(merge_from, merge_into);
        for (_, mapped_to) in self.merged_group_mapping.iter_mut() {
            if *mapped_to == merge_from {
                *mapped_to = merge_into;
            }
        }

        // Rewrite every expression that uses the source group as a child.
        let mut pending_recursive_merge = Vec::new();
        for (group_id, group) in self.groups.iter_mut() {
            let mut new_expr_list = HashSet::new();
            for expr_id in group.group_exprs.iter() {
                let expr = self.expr_id_to_expr_node[expr_id].clone();
                if expr.children.contains(&merge_from) {
                    let old_expr = expr.as_ref().clone();
                    let mut new_expr = expr.as_ref().clone();
                    new_expr.children.iter_mut().for_each(|x| {
                        if *x == merge_from {
                            *x = merge_into;
                        }
                    });
                    self.expr_id_to_expr_node
                        .insert(*expr_id, Arc::new(new_expr.clone()));
                    self.expr_node_to_expr_id.remove(&old_expr);
                    if let Some(dup_expr) = self.expr_node_to_expr_id.get(&new_expr) {
                        // The rewritten expression already exists: drop this
                        // one and remember which expression replaces it.
                        let dup_group_id = self.expr_id_to_group_id[dup_expr];
                        if dup_group_id != *group_id {
                            pending_recursive_merge.push((dup_group_id, *group_id));
                        }
                        self.expr_id_to_expr_node.remove(expr_id);
                        self.expr_id_to_group_id.remove(expr_id);
                        self.dup_expr_mapping.insert(*expr_id, *dup_expr);
                        // Keeps the group non-empty until the pending merge
                        // below brings both groups together.
                        new_expr_list.insert(*dup_expr);
                    } else {
                        self.expr_node_to_expr_id.insert(new_expr, *expr_id);
                        new_expr_list.insert(*expr_id);
                    }
                } else {
                    new_expr_list.insert(*expr_id);
                }
            }
            assert!(!new_expr_list.is_empty());
            group.group_exprs = new_expr_list;
        }

        for (merge_from, merge_into) in pending_recursive_merge {
            // An earlier merge in this loop may have removed either group, so
            // resolve both ids again before merging.
            let merge_from = self.reduce_group(merge_from);
            let merge_into = self.reduce_group(merge_into);
            self.merge_group_inner(merge_into, merge_from);
        }
    }

    /// Inserts `rel_node` into the memo, recursing into plan-node children.
    ///
    /// With `add_to_group_id == None` a new expression gets a new group. With
    /// `Some(group)` a new expression is added to that group, and an
    /// expression that already exists elsewhere causes its group to be merged
    /// into `group`. Returns the resulting group and expression ids.
    fn add_new_group_expr_inner(
        &mut self,
        rel_node: ArcPlanNode<T>,
        add_to_group_id: Option<GroupId>,
    ) -> anyhow::Result<(GroupId, ExprId)> {
        let children_group_ids = rel_node
            .children
            .iter()
            .map(|child| match child {
                PlanNodeOrGroup::Group(group) => self.reduce_group(*group),
                PlanNodeOrGroup::PlanNode(child) => {
                    // Children are added without a target group.
                    let (group, _) = self
                        .add_new_group_expr_inner(child.clone(), None)
                        .expect("should not trigger merge group");
                    self.reduce_group(group)
                }
            })
            .collect::<Vec<_>>();
        let memo_node = MemoPlanNode {
            typ: rel_node.typ.clone(),
            children: children_group_ids,
            predicates: rel_node
                .predicates
                .iter()
                .map(|x| self.add_new_pred(x.clone()))
                .collect(),
        };

        // Known expression: reuse it, merging groups when a target was named.
        if let Some(&expr_id) = self.expr_node_to_expr_id.get(&memo_node) {
            let group_id = self.expr_id_to_group_id[&expr_id];
            if let Some(add_to_group_id) = add_to_group_id {
                let add_to_group_id = self.reduce_group(add_to_group_id);
                self.merge_group_inner(add_to_group_id, group_id);
                return Ok((add_to_group_id, expr_id));
            }
            return Ok((group_id, expr_id));
        }

        // New expression: the expression id is allocated before the group id.
        let expr_id = self.next_expr_id();
        let group_id = if let Some(group_id) = add_to_group_id {
            group_id
        } else {
            self.next_group_id()
        };
        self.expr_id_to_expr_node
            .insert(expr_id, memo_node.clone().into());
        self.expr_id_to_group_id.insert(expr_id, group_id);
        self.expr_node_to_expr_id.insert(memo_node.clone(), expr_id);
        self.append_expr_to_group(expr_id, group_id, memo_node);
        Ok((group_id, expr_id))
    }

    /// Test helper: looks up the group and expression ids of a plan that is
    /// already in the memo, without modifying the memo.
    ///
    /// # Panics
    ///
    /// Panics when the plan or one of its predicates is not in the memo.
    #[cfg(test)]
    pub(crate) fn get_expr_info(&self, rel_node: ArcPlanNode<T>) -> (GroupId, ExprId) {
        let children_group_ids = rel_node
            .children
            .iter()
            .map(|child| match child {
                PlanNodeOrGroup::Group(group) => *group,
                PlanNodeOrGroup::PlanNode(child) => self.get_expr_info(child.clone()).0,
            })
            .collect::<Vec<_>>();
        let memo_node = MemoPlanNode {
            typ: rel_node.typ.clone(),
            children: children_group_ids,
            predicates: rel_node
                .predicates
                .iter()
                .map(|x| self.pred_node_to_pred_id[x])
                .collect(),
        };
        let Some(&expr_id) = self.expr_node_to_expr_id.get(&memo_node) else {
            unreachable!("not found {}", memo_node)
        };
        let group_id = self.expr_id_to_group_id[&expr_id];
        (group_id, expr_id)
    }

    /// Derives the properties of a new group from `memo_node`, producing one
    /// property per builder, in builder order.
    ///
    /// # Panics
    ///
    /// Panics when a child group or a predicate of `memo_node` is unknown.
    fn infer_properties(
        &self,
        memo_node: MemoPlanNode<T>,
    ) -> Vec<Box<dyn Any + 'static + Send + Sync>> {
        let child_properties = memo_node
            .children
            .iter()
            .map(|child| self.groups[child].properties.clone())
            .collect_vec();
        // The predicate list is the same for every builder, so resolve it
        // once instead of once per builder.
        let child_predicates = memo_node
            .predicates
            .iter()
            .map(|x| self.pred_id_to_pred_node[x].clone())
            .collect_vec();
        let mut props = Vec::with_capacity(self.property_builders.len());
        for (id, builder) in self.property_builders.iter().enumerate() {
            // Pick, for each child, the property produced by this builder.
            let child_properties = child_properties
                .iter()
                .map(|x| x[id].as_ref() as &dyn std::any::Any)
                .collect::<Vec<_>>();
            let prop = builder.derive_any(
                memo_node.typ.clone(),
                &child_predicates,
                child_properties.as_slice(),
            );
            props.push(prop);
        }
        props
    }

    /// Adds `expr_id` to `group_id`, creating the group (with properties
    /// derived from `memo_node`) when it does not exist yet.
    fn append_expr_to_group(
        &mut self,
        expr_id: ExprId,
        group_id: GroupId,
        memo_node: MemoPlanNode<T>,
    ) {
        trace!(event = "add_expr_to_group", group_id = %group_id, expr_id = %expr_id, memo_node = %memo_node);
        if let Entry::Occupied(mut entry) = self.groups.entry(group_id) {
            let group = entry.get_mut();
            group.group_exprs.insert(expr_id);
            return;
        }
        let mut group = Group {
            group_exprs: HashSet::new(),
            info: GroupInfo::default(),
            properties: self.infer_properties(memo_node).into(),
        };
        group.group_exprs.insert(expr_id);
        self.groups.insert(group_id, group);
        // A new group resolves to itself.
        self.merged_group_mapping.insert(group_id, group_id);
    }

    /// Resets the winner of every group to [`Winner::Unknown`].
    pub fn clear_winner(&mut self) {
        for group in self.groups.values_mut() {
            group.info.winner = Winner::Unknown;
        }
    }
}
// Unit tests for `NaiveMemo`: predicate interning, group merging and
// property derivation, using a small Join/Project/Scan node vocabulary.
#[cfg(test)]
mod tests {
use super::*;
use crate::{
nodes::{PredNode, Value},
property::PropertyBuilder,
};
// Plan node types used by the tests.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum MemoTestRelTyp {
Join,
Project,
Scan,
}
// Predicate node types used by the tests.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum MemoTestPredTyp {
List,
Expr,
TableName,
}
// `Display` simply reuses the `Debug` rendering of the variant name.
impl std::fmt::Display for MemoTestRelTyp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl std::fmt::Display for MemoTestPredTyp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl NodeType for MemoTestRelTyp {
type PredType = MemoTestPredTyp;
// Every test node type counts as logical.
fn is_logical(&self) -> bool {
matches!(self, Self::Project | Self::Scan | Self::Join)
}
}
// Builds a join of `left` and `right` with the condition `cond`.
fn join(
left: impl Into<PlanNodeOrGroup<MemoTestRelTyp>>,
right: impl Into<PlanNodeOrGroup<MemoTestRelTyp>>,
cond: ArcPredNode<MemoTestRelTyp>,
) -> ArcPlanNode<MemoTestRelTyp> {
Arc::new(PlanNode {
typ: MemoTestRelTyp::Join,
children: vec![left.into(), right.into()],
predicates: vec![cond],
})
}
// Builds a scan of `table`; the table name is carried as a predicate.
fn scan(table: &str) -> ArcPlanNode<MemoTestRelTyp> {
Arc::new(PlanNode {
typ: MemoTestRelTyp::Scan,
children: vec![],
predicates: vec![table_name(table)],
})
}
// Builds a table-name predicate holding `table` as a string value.
fn table_name(table: &str) -> ArcPredNode<MemoTestRelTyp> {
Arc::new(PredNode {
typ: MemoTestPredTyp::TableName,
children: vec![],
data: Some(Value::String(table.to_string().into())),
})
}
// Builds a projection of `input` with the expression list `expr_list`.
fn project(
input: impl Into<PlanNodeOrGroup<MemoTestRelTyp>>,
expr_list: ArcPredNode<MemoTestRelTyp>,
) -> ArcPlanNode<MemoTestRelTyp> {
Arc::new(PlanNode {
typ: MemoTestRelTyp::Project,
children: vec![input.into()],
predicates: vec![expr_list],
})
}
// Builds a list predicate whose children are `items`.
fn list(items: Vec<ArcPredNode<MemoTestRelTyp>>) -> ArcPredNode<MemoTestRelTyp> {
Arc::new(PredNode {
typ: MemoTestPredTyp::List,
children: items,
data: None,
})
}
// Builds a leaf expression predicate holding `data`.
fn expr(data: Value) -> ArcPredNode<MemoTestRelTyp> {
Arc::new(PredNode {
typ: MemoTestPredTyp::Expr,
children: vec![],
data: Some(data),
})
}
// Wraps a group id as a plan-node child reference.
fn group(group_id: GroupId) -> PlanNodeOrGroup<MemoTestRelTyp> {
PlanNodeOrGroup::Group(group_id)
}
// Adding the same predicate twice must return the same id.
#[test]
fn add_predicate() {
let mut memo = NaiveMemo::<MemoTestRelTyp>::new(Arc::new([]));
let pred_node = list(vec![expr(Value::Int32(233))]);
let p1 = memo.add_new_pred(pred_node.clone());
let p2 = memo.add_new_pred(pred_node.clone());
assert_eq!(p1, p2);
}
// Adding a different expression to an existing group grows that group.
#[test]
fn group_merge_1() {
let mut memo = NaiveMemo::new(Arc::new([]));
let (group_id, _) =
memo.add_new_expr(join(scan("t1"), scan("t2"), expr(Value::Bool(true))));
memo.add_expr_to_group(
join(scan("t2"), scan("t1"), expr(Value::Bool(true))).into(),
group_id,
);
assert_eq!(memo.get_group(group_id).group_exprs.len(), 2);
}
// Adding an identical plan twice resolves to the same group.
#[test]
fn group_merge_2() {
let mut memo = NaiveMemo::new(Arc::new([]));
let (group_id_1, _) = memo.add_new_expr(project(
join(scan("t1"), scan("t2"), expr(Value::Bool(true))),
list(vec![expr(Value::Int64(1))]),
));
let (group_id_2, _) = memo.add_new_expr(project(
join(scan("t1"), scan("t2"), expr(Value::Bool(true))),
list(vec![expr(Value::Int64(1))]),
));
assert_eq!(group_id_1, group_id_2);
}
// Merging the two scan groups must also merge the projections above them.
#[test]
fn group_merge_3() {
let mut memo = NaiveMemo::new(Arc::new([]));
let expr1 = project(scan("t1"), list(vec![expr(Value::Int64(1))]));
let expr2 = project(scan("t1-alias"), list(vec![expr(Value::Int64(1))]));
memo.add_new_expr(expr1.clone());
memo.add_new_expr(expr2.clone());
let (group_id_expr, _) = memo.get_expr_info(scan("t1"));
memo.add_expr_to_group(scan("t1-alias").into(), group_id_expr);
let (group_1, _) = memo.get_expr_info(expr1);
let (group_2, _) = memo.get_expr_info(expr2);
assert_eq!(group_1, group_2);
}
// Same as above with two projection levels: the merge must cascade upwards
// through both levels.
#[test]
fn group_merge_4() {
let mut memo = NaiveMemo::new(Arc::new([]));
let expr1 = project(
project(scan("t1"), list(vec![expr(Value::Int64(1))])),
list(vec![expr(Value::Int64(2))]),
);
let expr2 = project(
project(scan("t1-alias"), list(vec![expr(Value::Int64(1))])),
list(vec![expr(Value::Int64(2))]),
);
memo.add_new_expr(expr1.clone());
memo.add_new_expr(expr2.clone());
let (group_id_expr, _) = memo.get_expr_info(scan("t1"));
memo.add_expr_to_group(scan("t1-alias").into(), group_id_expr);
let (group_1, _) = memo.get_expr_info(expr1.clone());
let (group_2, _) = memo.get_expr_info(expr2.clone());
assert_eq!(group_1, group_2);
let (group_1, _) = memo.get_expr_info(expr1.child_rel(0));
let (group_2, _) = memo.get_expr_info(expr2.child_rel(0));
assert_eq!(group_1, group_2);
}
// Adding the inner projection of the first plan to the inner group of the
// second plan makes the two outer expressions identical, so the old
// expression ids must resolve to the same memoized node.
#[test]
fn group_merge_5() {
let mut memo = NaiveMemo::new(Arc::new([]));
let expr1 = project(
project(scan("t1"), list(vec![expr(Value::Int64(1))])),
list(vec![expr(Value::Int64(2))]),
);
let expr2 = project(
project(scan("t1-alias"), list(vec![expr(Value::Int64(1))])),
list(vec![expr(Value::Int64(2))]),
);
let (_, expr1_id) = memo.add_new_expr(expr1.clone());
let (_, expr2_id) = memo.add_new_expr(expr2.clone());
let (scan_t1, _) = memo.get_expr_info(scan("t1"));
let pred = list(vec![expr(Value::Int64(1))]);
let proj_binding = project(group(scan_t1), pred);
let middle_proj_2 = memo.get_expr_memoed(expr2_id).children[0];
memo.add_expr_to_group(proj_binding.into(), middle_proj_2);
assert_eq!(
memo.get_expr_memoed(expr1_id),
memo.get_expr_memoed(expr2_id)
); assert_eq!(memo.get_expr_info(expr1), memo.get_expr_info(expr2));
}
// Property builder used by `logical_property`; it produces a list of strings.
struct TestPropertyBuilder;
#[derive(Clone, Debug)]
struct TestProp(Vec<String>);
impl std::fmt::Display for TestProp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.0)
}
}
impl PropertyBuilder<MemoTestRelTyp> for TestPropertyBuilder {
type Prop = TestProp;
// Join: concatenation of both children's lists. Project: the values of the
// projection list as strings. Scan: the single entry "scan_col".
fn derive(
&self,
typ: MemoTestRelTyp,
pred: &[ArcPredNode<MemoTestRelTyp>],
children: &[&Self::Prop],
) -> Self::Prop {
match typ {
MemoTestRelTyp::Join => {
let mut a = children[0].0.clone();
let b = children[1].0.clone();
a.extend(b);
TestProp(a)
}
MemoTestRelTyp::Project => {
let preds = &pred[0].children;
TestProp(
preds
.iter()
.map(|x| x.data.as_ref().unwrap().as_i64().to_string())
.collect(),
)
}
MemoTestRelTyp::Scan => TestProp(vec!["scan_col".to_string()]),
}
}
fn property_name(&self) -> &'static str {
"test"
}
}
// The property of the join group combines the scan's property with the
// projection's property, left child first.
#[test]
fn logical_property() {
let mut memo = NaiveMemo::new(Arc::new([Box::new(TestPropertyBuilder)]));
let (group_id, _) = memo.add_new_expr(join(
scan("t1"),
project(
scan("t2"),
list(vec![expr(Value::Int64(1)), expr(Value::Int64(2))]),
),
expr(Value::Bool(true)),
));
let group = memo.get_group(group_id);
assert_eq!(group.properties.len(), 1);
assert_eq!(
group.properties[0].downcast_ref::<TestProp>().unwrap().0,
vec!["scan_col", "1", "2"]
);
}
}