locustdb 0.3.4

Embeddable high-performance analytics database.
Documentation
use crate::QueryError;
use crate::engine::*;
use crate::mem_store::*;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::result::Result;

use self::query_plan::prepare;
use self::QueryPlan::*;

/// Accumulates the operations of a query plan together with the bookkeeping needed
/// to build and roll back that plan (`checkpoint` / `reset`).
#[derive(Default)]
pub struct QueryPlanner {
    /// Planned operations in insertion order.
    pub operations: Vec<QueryPlan>,
    /// Indexed by buffer index (`BufferRef::i`); holds the index into `operations` of
    /// the operation associated with (presumably producing) that buffer.
    pub buffer_to_operation: Vec<Option<usize>>,
    /// Previously created output buffers keyed by a 16-byte key
    /// (presumably an expression hash used for subexpression reuse — TODO confirm).
    pub cache: HashMap<[u8; 16], Vec<TypedBufferRef>>,
    /// Allocates the buffers referenced by `operations`.
    pub buffer_provider: BufferProvider,
    /// Length of `operations` recorded by the last `checkpoint`.
    checkpoint: usize,
    /// Copy of `cache` recorded by the last `checkpoint`.
    cache_checkpoint: HashMap<[u8; 16], Vec<TypedBufferRef>>,
}

impl QueryPlanner {
    /// Applies the nullability rewrites to the planned operations, then prepares every
    /// operation into a fresh `QueryExecutor` sized for all allocated buffers.
    ///
    /// `constant_vecs` is passed to each operation's `prepare` call.
    ///
    /// # Errors
    /// Returns the first `QueryError` produced while preparing an operation.
    pub fn prepare<'a>(&mut self, mut constant_vecs: Vec<BoxedData<'a>>) -> Result<QueryExecutor<'a>, QueryError> {
        self.perform_rewrites();

        let mut result = QueryExecutor::default();
        result.set_buffer_count(self.buffer_provider.buffer_count());
        for operation in &self.operations {
            prepare(operation.clone(), &mut constant_vecs, &mut result)?;
        }
        Ok(result)
    }

    /// Records the current number of operations and a snapshot of `cache` so that a
    /// later `reset` can roll the planner back to this point.
    pub fn checkpoint(&mut self) {
        self.checkpoint = self.operations.len();
        self.cache_checkpoint = self.cache.clone();
    }

    /// Discards operations added since the last `checkpoint` and restores `cache` to
    /// the snapshot taken then.
    ///
    /// The snapshot is copied rather than swapped out. Calling `reset` repeatedly
    /// without an intervening `checkpoint` therefore keeps restoring the same state;
    /// it does not flip back to the discarded cache, whose entries may refer to
    /// truncated operations.
    ///
    /// NOTE(review): `buffer_to_operation` and `buffer_provider` are not rolled back.
    pub fn reset(&mut self) {
        self.operations.truncate(self.checkpoint);
        // `clone_from` reuses `cache`'s existing allocation where possible.
        self.cache.clone_from(&self.cache_checkpoint);
    }

    /// Returns the operation recorded for `buffer` in `buffer_to_operation`.
    ///
    /// # Panics
    /// Panics if no operation is recorded for the buffer, or if its index is out of
    /// range for `buffer_to_operation` or `operations`.
    pub fn resolve(&self, buffer: &TypedBufferRef) -> &QueryPlan {
        let op_index = self.buffer_to_operation[buffer.buffer.i]
            .unwrap_or_else(|| panic!("No entry found for {:?}", buffer));
        &self.operations[op_index]
    }

    /// Always `true` here; presumably consulted by callers before reusing `cache`
    /// entries — TODO confirm.
    pub fn enable_common_subexpression_elimination(&self) -> bool { true }

    /// Runs `propagate_nullability` over every operation present when the call starts.
    ///
    /// For each rewritten operation, the first replacement takes over its slot in
    /// `operations` and the remaining replacements are appended at the end. Appended
    /// operations are not visited again.
    fn perform_rewrites(&mut self) {
        let original_len = self.operations.len();
        for i in 0..original_len {
            let ops = match propagate_nullability(&self.operations[i], &mut self.buffer_provider) {
                Rewrite::ReplaceWith(ops) => ops,
                Rewrite::None => continue,
            };
            trace!("Replacing {:#?} with {:#?}", self.operations[i], ops);
            let mut ops = ops.into_iter();
            // Move the first replacement into place instead of cloning it.
            self.operations[i] = ops
                .next()
                .expect("Rewrite::ReplaceWith must contain at least one operation");
            self.operations.extend(ops);
        }
    }
}

/// Outcome of attempting to rewrite a single planned operation.
enum Rewrite {
    /// The operation is kept unchanged.
    None,
    /// The operation is replaced by these operations.
    ReplaceWith(Vec<QueryPlan>),
}

/// Rewrites one operation involving nullable data into operations that work on the
/// non-nullable payload buffers, plus explicit null-map handling.
///
/// Returns `Rewrite::None` when `operation` needs no rewrite. Any intermediate
/// buffers are allocated from `bp`. Allocation order determines buffer indices.
///
/// # Panics
/// Panics if the nullable result of a checked arithmetic operation cannot be viewed
/// as a nullable `i64` buffer (`nullable_i64().unwrap()`).
fn propagate_nullability(operation: &QueryPlan, bp: &mut BufferProvider) -> Rewrite {
    match *operation {
        Cast { input, casted } if input.is_nullable() && casted.tag != EncodingType::Val => {
            let casted_non_nullable = bp.named_buffer("casted_non_nullable", casted.tag.non_nullable());
            Rewrite::ReplaceWith(vec![
                Cast {
                    input: input.forget_nullability(),
                    casted: casted_non_nullable,
                },
                PropagateNullability {
                    nullable: input,
                    data: casted_non_nullable,
                    nullable_data: casted,
                },
            ])
        }

        // Arithmetic and boolean operators: the op is emitted before its null handling.
        Add { lhs, rhs, sum } if sum.is_nullable() => lower_nullable_binary(
            bp, lhs, rhs, sum, "sum_non_null", OpPlacement::BeforeNullHandling,
            |lhs, rhs, sum| Add { lhs, rhs, sum },
        ),
        Subtract { lhs, rhs, difference } if difference.is_nullable() => lower_nullable_binary(
            bp, lhs, rhs, difference, "difference_non_null", OpPlacement::BeforeNullHandling,
            |lhs, rhs, difference| Subtract { lhs, rhs, difference },
        ),
        Multiply { lhs, rhs, product } if product.is_nullable() => lower_nullable_binary(
            bp, lhs, rhs, product, "product_non_null", OpPlacement::BeforeNullHandling,
            |lhs, rhs, product| Multiply { lhs, rhs, product },
        ),
        Divide { lhs, rhs, division } if division.is_nullable() => lower_nullable_binary(
            bp, lhs, rhs, division, "division_non_null", OpPlacement::BeforeNullHandling,
            |lhs, rhs, division| Divide { lhs, rhs, division },
        ),
        Modulo { lhs, rhs, modulo } if modulo.is_nullable() => lower_nullable_binary(
            bp, lhs, rhs, modulo, "modulo_non_null", OpPlacement::BeforeNullHandling,
            |lhs, rhs, modulo| Modulo { lhs, rhs, modulo },
        ),
        And { lhs, rhs, and } if and.is_nullable() => lower_nullable_binary(
            bp, lhs, rhs, and, "and_non_null", OpPlacement::BeforeNullHandling,
            |lhs, rhs, and| And { lhs, rhs, and },
        ),
        Or { lhs, rhs, or } if or.is_nullable() => lower_nullable_binary(
            bp, lhs, rhs, or, "or_non_null", OpPlacement::BeforeNullHandling,
            |lhs, rhs, or| Or { lhs, rhs, or },
        ),

        // Checked arithmetic: a single presence map feeds a `NullableChecked*` op.
        CheckedAdd { lhs, rhs, sum } if sum.is_nullable() => lower_checked(bp, lhs, rhs, |lhs, rhs, present| {
            NullableCheckedAdd { lhs, rhs, present, sum: sum.nullable_i64().unwrap() }
        }),
        CheckedSubtract { lhs, rhs, difference } if difference.is_nullable() => lower_checked(bp, lhs, rhs, |lhs, rhs, present| {
            NullableCheckedSubtract { lhs, rhs, present, difference: difference.nullable_i64().unwrap() }
        }),
        CheckedMultiply { lhs, rhs, product } if product.is_nullable() => lower_checked(bp, lhs, rhs, |lhs, rhs, present| {
            NullableCheckedMultiply { lhs, rhs, present, product: product.nullable_i64().unwrap() }
        }),
        CheckedDivide { lhs, rhs, division } if division.is_nullable() => lower_checked(bp, lhs, rhs, |lhs, rhs, present| {
            NullableCheckedDivide { lhs, rhs, present, division: division.nullable_i64().unwrap() }
        }),
        CheckedModulo { lhs, rhs, modulo } if modulo.is_nullable() => lower_checked(bp, lhs, rhs, |lhs, rhs, present| {
            NullableCheckedModulo { lhs, rhs, present, modulo: modulo.nullable_i64().unwrap() }
        }),

        // Comparisons: null handling is emitted before the op.
        // NOTE(review): this differs from the arithmetic arms above. The ordering is
        // preserved as-is; presumably the executor orders operations by data
        // dependency — TODO confirm.
        LessThan { lhs, rhs, less_than } if less_than.is_nullable() => lower_nullable_binary(
            bp, lhs, rhs, less_than, "less_than_non_null", OpPlacement::AfterNullHandling,
            |lhs, rhs, less_than| LessThan { lhs, rhs, less_than },
        ),
        LessThanEquals { lhs, rhs, less_than_equals } if less_than_equals.is_nullable() => lower_nullable_binary(
            bp, lhs, rhs, less_than_equals, "less_than_equals_non_null", OpPlacement::AfterNullHandling,
            |lhs, rhs, less_than_equals| LessThanEquals { lhs, rhs, less_than_equals },
        ),
        Equals { lhs, rhs, equals } if equals.is_nullable() => lower_nullable_binary(
            bp, lhs, rhs, equals, "equals_non_null", OpPlacement::AfterNullHandling,
            |lhs, rhs, equals| Equals { lhs, rhs, equals },
        ),
        NotEquals { lhs, rhs, not_equals } if not_equals.is_nullable() => lower_nullable_binary(
            bp, lhs, rhs, not_equals, "not_equals_non_null", OpPlacement::AfterNullHandling,
            |lhs, rhs, not_equals| NotEquals { lhs, rhs, not_equals },
        ),

        MergeKeep { take_left, lhs, rhs, merged } if lhs.is_nullable() != rhs.is_nullable() => {
            // Exactly one side is nullable: wrap the other one so both inputs are nullable.
            let mut ops = Vec::with_capacity(2);
            let lhs = if lhs.is_nullable() { lhs } else {
                let lhs_nullable = bp.named_buffer("lhs_nullable", lhs.tag.nullable());
                ops.push(MakeNullable { data: lhs, present: bp.buffer_u8("present"), nullable: lhs_nullable });
                lhs_nullable
            };
            let rhs = if rhs.is_nullable() { rhs } else {
                let rhs_nullable = bp.named_buffer("rhs_nullable", rhs.tag.nullable());
                ops.push(MakeNullable { data: rhs, present: bp.buffer_u8("present"), nullable: rhs_nullable });
                rhs_nullable
            };
            ops.push(MergeKeep { take_left, lhs, rhs, merged });
            Rewrite::ReplaceWith(ops)
        }
        DictLookup { indices, offset_len, backing_store, decoded } if indices.is_nullable() => {
            let decoded_non_null = bp.named_buffer("decoded_non_null", decoded.tag.non_nullable());
            Rewrite::ReplaceWith(vec![
                DictLookup {
                    indices: indices.forget_nullability(),
                    offset_len,
                    backing_store,
                    decoded: decoded_non_null,
                },
                PropagateNullability {
                    nullable: indices,
                    data: decoded_non_null,
                    nullable_data: decoded,
                },
            ])
        }
        _ => Rewrite::None,
    }
}

/// Position of a lowered binary operation relative to the operations that rebuild
/// its nullable output.
#[derive(Clone, Copy)]
enum OpPlacement {
    /// Emit `[op, null handling...]`.
    BeforeNullHandling,
    /// Emit `[null handling..., op]`.
    AfterNullHandling,
}

/// Lowers a binary operation whose `output` is nullable.
///
/// `build` receives the non-nullable views of `lhs` and `rhs` plus a freshly
/// allocated non-nullable output buffer named `non_null_name`. `combine_nulls` then
/// produces the operations that reassemble `output` from that buffer.
fn lower_nullable_binary(
    bp: &mut BufferProvider,
    lhs: TypedBufferRef,
    rhs: TypedBufferRef,
    output: TypedBufferRef,
    non_null_name: &'static str,
    placement: OpPlacement,
    build: impl FnOnce(TypedBufferRef, TypedBufferRef, TypedBufferRef) -> QueryPlan,
) -> Rewrite {
    // Buffer indices follow allocation order: the non-null output is allocated
    // before `combine_nulls` allocates any null-map buffer.
    let output_non_null = bp.named_buffer(non_null_name, output.tag.non_nullable());
    let op = build(lhs.forget_nullability(), rhs.forget_nullability(), output_non_null);
    let null_handling = combine_nulls(bp, lhs, rhs, output_non_null, output);
    let ops = match placement {
        OpPlacement::BeforeNullHandling => {
            let mut ops = Vec::with_capacity(null_handling.len() + 1);
            ops.push(op);
            ops.extend(null_handling);
            ops
        }
        OpPlacement::AfterNullHandling => {
            let mut ops = null_handling;
            ops.push(op);
            ops
        }
    };
    Rewrite::ReplaceWith(ops)
}

/// Lowers a checked arithmetic operation whose result is nullable.
///
/// `combine_nulls2` allocates a presence buffer and returns the op that fills it.
/// `build` then receives the non-nullable views of `lhs` and `rhs` plus that buffer.
/// The fill op is emitted first.
fn lower_checked(
    bp: &mut BufferProvider,
    lhs: TypedBufferRef,
    rhs: TypedBufferRef,
    build: impl FnOnce(TypedBufferRef, TypedBufferRef, BufferRef<u8>) -> QueryPlan,
) -> Rewrite {
    let (present, fill_present) = combine_nulls2(bp, lhs, rhs);
    let op = build(lhs.forget_nullability(), rhs.forget_nullability(), present);
    Rewrite::ReplaceWith(vec![fill_present, op])
}

/// Builds the operations that turn `data`, the non-nullable result of a binary
/// operation, back into the nullable buffer `nullable_data`.
///
/// When both inputs are nullable, a new `u8` buffer is allocated and filled by
/// `CombineNullMaps`, then `AssembleNullable` joins it with `data`. Otherwise the
/// null information of `lhs` (if nullable) or else `rhs` is applied with
/// `PropagateNullability`; no buffer is allocated in that case.
fn combine_nulls(bp: &mut BufferProvider,
                 lhs: TypedBufferRef,
                 rhs: TypedBufferRef,
                 data: TypedBufferRef,
                 nullable_data: TypedBufferRef) -> Vec<QueryPlan> {
    match (lhs.is_nullable(), rhs.is_nullable()) {
        (true, true) => {
            let present = bp.buffer_u8("combined_null_map");
            vec![
                CombineNullMaps { lhs, rhs, present },
                AssembleNullable { data, present, nullable: nullable_data },
            ]
        }
        (lhs_is_nullable, _) => {
            let nullable = if lhs_is_nullable { lhs } else { rhs };
            vec![PropagateNullability { nullable, data, nullable_data }]
        }
    }
}

/// Allocates a `u8` presence buffer for the inputs of a binary operation and returns
/// it together with the operation that fills it.
///
/// Both inputs nullable: `CombineNullMaps` over `lhs` and `rhs`. Otherwise:
/// `GetNullMap` on `lhs` if it is nullable, else on `rhs`.
fn combine_nulls2(bp: &mut BufferProvider,
                  lhs: TypedBufferRef,
                  rhs: TypedBufferRef) -> (BufferRef<u8>, QueryPlan) {
    let present = bp.buffer_u8("combined_null_map");
    let fill_present = match (lhs.is_nullable(), rhs.is_nullable()) {
        (true, true) => CombineNullMaps { lhs, rhs, present },
        (true, false) => GetNullMap { nullable: lhs, present },
        (false, _) => GetNullMap { nullable: rhs, present },
    };
    (present, fill_present)
}

/// Hands out sequentially indexed buffers for a query plan.
#[derive(Default)]
pub struct BufferProvider {
    /// Buffers registered through `shared_buffer`, keyed by name.
    shared_buffers: HashMap<&'static str, TypedBufferRef>,
    /// Number of buffers allocated so far; also the index of the next buffer.
    buffer_count: usize,
}

impl BufferProvider {
    /// Allocates a new buffer with the next free index, debug `name` and encoding `tag`.
    pub fn named_buffer(&mut self, name: &'static str, tag: EncodingType) -> TypedBufferRef {
        let buffer = TypedBufferRef::new(BufferRef { i: self.buffer_count, name, t: PhantomData }, tag);
        self.buffer_count += 1;
        buffer
    }

    // The typed helpers below allocate through `named_buffer` with a fixed tag and
    // convert the result to the matching typed `BufferRef`. Each `unwrap` relies on
    // the conversion accepting the tag that was just assigned.

    /// Allocates a `&str` buffer.
    pub fn buffer_str<'a>(&mut self, name: &'static str) -> BufferRef<&'a str> {
        self.named_buffer(name, EncodingType::Str).str().unwrap()
    }

    /// Allocates a `usize` buffer.
    pub fn buffer_usize(&mut self, name: &'static str) -> BufferRef<usize> {
        self.named_buffer(name, EncodingType::USize).usize().unwrap()
    }

    /// Allocates an `i64` buffer.
    pub fn buffer_i64(&mut self, name: &'static str) -> BufferRef<i64> {
        self.named_buffer(name, EncodingType::I64).i64().unwrap()
    }

    /// Allocates a `u32` buffer.
    pub fn buffer_u32(&mut self, name: &'static str) -> BufferRef<u32> {
        self.named_buffer(name, EncodingType::U32).u32().unwrap()
    }

    /// Allocates a `u8` buffer.
    pub fn buffer_u8(&mut self, name: &'static str) -> BufferRef<u8> {
        self.named_buffer(name, EncodingType::U8).u8().unwrap()
    }

    /// Allocates a nullable `i64` buffer.
    pub fn nullable_buffer_i64(&mut self, name: &'static str) -> BufferRef<Nullable<i64>> {
        self.named_buffer(name, EncodingType::NullableI64).nullable_i64().unwrap()
    }

    /// Allocates a `Val` buffer.
    pub fn buffer_val<'a>(&mut self, name: &'static str) -> BufferRef<Val<'a>> {
        self.named_buffer(name, EncodingType::Val).val().unwrap()
    }

    /// Allocates a `ValRows` buffer.
    pub fn buffer_val_rows<'a>(&mut self, name: &'static str) -> BufferRef<ValRows<'a>> {
        self.named_buffer(name, EncodingType::ValRows).val_rows().unwrap()
    }

    /// Allocates a scalar `i64` buffer.
    pub fn buffer_scalar_i64(&mut self, name: &'static str) -> BufferRef<Scalar<i64>> {
        self.named_buffer(name, EncodingType::ScalarI64).scalar_i64().unwrap()
    }

    /// Allocates a scalar `&str` buffer.
    pub fn buffer_scalar_str<'a>(&mut self, name: &'static str) -> BufferRef<Scalar<&'a str>> {
        self.named_buffer(name, EncodingType::ScalarStr).scalar_str().unwrap()
    }

    /// Allocates a scalar `String` buffer.
    pub fn buffer_scalar_string(&mut self, name: &'static str) -> BufferRef<Scalar<String>> {
        self.named_buffer(name, EncodingType::ScalarString).scalar_string().unwrap()
    }

    /// Allocates a `MergeOp` buffer.
    pub fn buffer_merge_op(&mut self, name: &'static str) -> BufferRef<MergeOp> {
        self.named_buffer(name, EncodingType::MergeOp).merge_op().unwrap()
    }

    /// Allocates a `Premerge` buffer.
    pub fn buffer_premerge(&mut self, name: &'static str) -> BufferRef<Premerge> {
        self.named_buffer(name, EncodingType::Premerge).premerge().unwrap()
    }

    /// Returns the buffer registered under `name`, allocating and registering a new
    /// one with encoding `tag` on first use.
    ///
    /// NOTE(review): on later calls `tag` is ignored; the returned buffer keeps the
    /// tag it was first created with.
    pub fn shared_buffer(&mut self, name: &'static str, tag: EncodingType) -> TypedBufferRef {
        // A single lookup on a hit. The entry API cannot be used here because
        // `named_buffer` needs `&mut self` while the entry would borrow `shared_buffers`.
        if let Some(&buffer) = self.shared_buffers.get(name) {
            return buffer;
        }
        let buffer = self.named_buffer(name, tag);
        self.shared_buffers.insert(name, buffer);
        buffer
    }

    /// Number of buffers allocated so far.
    pub fn buffer_count(&self) -> usize { self.buffer_count }
}