lbug 0.18.0

An in-process property graph database management system built for query speed and scalability
#include "binder/expression/aggregate_function_expression.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression/scalar_function_expression.h"
#include "common/copy_constructors.h"
#include "common/types/types.h"
#include "function/aggregate/count_star.h"
#include "function/arithmetic/vector_arithmetic_functions.h"
#include "function/comparison/vector_comparison_functions.h"
#include "main/client_context.h"
#include "planner/operator/factorization/flatten_resolver.h"
#include "planner/operator/logical_aggregate.h"
#include "planner/operator/logical_filter.h"
#include "planner/operator/logical_flatten.h"
#include "planner/operator/logical_projection.h"
#include "processor/operator/aggregate/hash_aggregate.h"
#include "processor/operator/aggregate/hash_aggregate_scan.h"
#include "processor/operator/aggregate/packed_filtered_count.h"
#include "processor/operator/aggregate/simple_aggregate.h"
#include "processor/operator/aggregate/simple_aggregate_scan.h"
#include "processor/plan_mapper.h"
#include "processor/result/result_set_descriptor.h"

using namespace lbug::binder;
using namespace lbug::common;
using namespace lbug::function;
using namespace lbug::planner;

namespace lbug {
namespace processor {

static std::vector<AggregateInfo> getAggregateInputInfos(const expression_vector& keys,
    const expression_vector& aggregates, const Schema& schema) {
    // Collect unFlat groups from
    std::unordered_set<f_group_pos> groupByGroupPosSet;
    for (auto& expression : keys) {
        groupByGroupPosSet.insert(schema.getGroupPos(*expression));
    }
    std::unordered_set<f_group_pos> unFlatAggregateGroupPosSet;
    for (auto groupPos : schema.getGroupsPosInScope()) {
        if (groupByGroupPosSet.contains(groupPos)) {
            continue;
        }
        if (schema.getGroup(groupPos)->isFlat()) {
            continue;
        }
        unFlatAggregateGroupPosSet.insert(groupPos);
    }
    std::vector<AggregateInfo> result;
    for (auto& expression : aggregates) {
        auto aggregateVectorPos = DataPos::getInvalidPos();
        if (expression->getNumChildren() != 0) { // COUNT(*) has no children
            auto child = expression->getChild(0);
            aggregateVectorPos = DataPos{schema.getExpressionPos(*child)};
        }
        std::vector<data_chunk_pos_t> multiplicityChunksPos;
        for (auto& groupPos : unFlatAggregateGroupPosSet) {
            if (groupPos != aggregateVectorPos.dataChunkPos) {
                multiplicityChunksPos.push_back(groupPos);
            }
        }
        auto aggExpr = expression->constPtrCast<AggregateFunctionExpression>();
        auto distinctAggKeyType = aggExpr->isDistinct() ?
                                      expression->getChild(0)->getDataType().copy() :
                                      LogicalType::ANY();
        result.emplace_back(aggregateVectorPos, std::move(multiplicityChunksPos),
            std::move(distinctAggKeyType));
    }
    return result;
}

// Returns, in their original order, the expressions whose factorization group in `schema` has
// the requested flatness (`isFlat`).
static expression_vector getKeyExpressions(const expression_vector& expressions,
    const Schema& schema, bool isFlat) {
    expression_vector selected;
    for (const auto& candidate : expressions) {
        const auto groupPos = schema.getGroupPos(*candidate);
        if (schema.getGroup(groupPos)->isFlat() != isFlat) {
            continue;
        }
        selected.push_back(candidate);
    }
    return selected;
}

static std::vector<AggregateFunction> getAggFunctions(const expression_vector& aggregates) {
    std::vector<AggregateFunction> aggregateFunctions;
    for (auto& expression : aggregates) {
        auto aggExpr = expression->constPtrCast<AggregateFunctionExpression>();
        aggregateFunctions.push_back(aggExpr->getFunction().copy());
    }
    return aggregateFunctions;
}

// Writes an aggregate result whose state is an AggregateStateWithNull: the null flag of the
// state is copied to `vector[pos]`, and the value is written only when the state is not null.
static void writeAggResultWithNullToVector(ValueVector& vector, uint64_t pos,
    AggregateState* aggregateState) {
    const auto& stateWithNull = aggregateState->constCast<AggregateStateWithNull>();
    const auto resultIsNull = stateWithNull.isNull;
    vector.setNull(pos, resultIsNull);
    if (resultIsNull) {
        return;
    }
    aggregateState->writeToVector(&vector, pos);
}

// Writes an aggregate result without consulting any null flag on the state: `vector[pos]` is
// always marked non-null and the state's value is always written.
static void writeAggResultWithoutNullToVector(ValueVector& vector, uint64_t pos,
    AggregateState* aggregateState) {
    constexpr bool resultIsNull = false;
    vector.setNull(pos, resultIsNull);
    auto* output = &vector;
    aggregateState->writeToVector(output, pos);
}

static std::vector<move_agg_result_to_vector_func> getMoveAggResultToVectorFuncs(
    std::vector<AggregateFunction>& aggregateFunctions) {
    std::vector<move_agg_result_to_vector_func> moveAggResultToVectorFuncs;
    for (auto& aggregateFunction : aggregateFunctions) {
        if (aggregateFunction.needToHandleNulls) {
            moveAggResultToVectorFuncs.push_back(writeAggResultWithoutNullToVector);
        } else {
            moveAggResultToVectorFuncs.push_back(writeAggResultWithNullToVector);
        }
    }
    return moveAggResultToVectorFuncs;
}

// Descends through consecutive FLATTEN operators (always via child 0) and returns the first
// operator that is not a FLATTEN.
static const LogicalOperator* unwrapFlattens(const LogicalOperator* op) {
    const LogicalOperator* current = op;
    for (; current->getOperatorType() == LogicalOperatorType::FLATTEN;
         current = current->getChild(0).get()) {}
    return current;
}

static bool isScalarFunction(const std::shared_ptr<Expression>& expression,
    const std::string& name) {
    auto scalarFunction = dynamic_cast<const ScalarFunctionExpression*>(expression.get());
    if (scalarFunction == nullptr) {
        return false;
    }
    return scalarFunction->getFunction().name == name;
}

// True iff `expression` is a LITERAL of logical type INT64 whose stored value equals `value`.
// NOTE(review): the literal is not checked for NULL before its value is read — confirm that a
// NULL INT64 literal cannot reach this helper or that reading its value is well-defined.
static bool isInt64Literal(const std::shared_ptr<Expression>& expression, int64_t value) {
    if (expression->expressionType != ExpressionType::LITERAL) {
        return false;
    }
    if (expression->getDataType().getLogicalTypeID() != LogicalTypeID::INT64) {
        return false;
    }
    const auto& literal = expression->constCast<LiteralExpression>();
    return literal.getValue().getValue<int64_t>() == value;
}

// Recognises predicates of the shape `(a + b) % 10 = 0` (or `0 = (a + b) % 10`), where both
// literals are INT64. On a match returns the two operands of the addition, in order; otherwise
// returns std::nullopt.
static std::optional<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>>
tryGetModuloSumPredicateInputs(const std::shared_ptr<Expression>& predicate) {
    const bool isBinaryEquals = isScalarFunction(predicate, function::EqualsFunction::name) &&
                                predicate->getNumChildren() == 2;
    if (!isBinaryEquals) {
        return std::nullopt;
    }
    auto lhs = predicate->getChild(0);
    auto rhs = predicate->getChild(1);
    // Accept the zero literal on either side of the comparison.
    const bool zeroOnLeft = isInt64Literal(lhs, 0);
    auto moduloExpr = zeroOnLeft ? rhs : lhs;
    auto zeroExpr = zeroOnLeft ? lhs : rhs;
    const bool isModuloTenEqualsZero =
        isInt64Literal(zeroExpr, 0) &&
        isScalarFunction(moduloExpr, function::ModuloFunction::name) &&
        moduloExpr->getNumChildren() == 2 && isInt64Literal(moduloExpr->getChild(1), 10);
    if (!isModuloTenEqualsZero) {
        return std::nullopt;
    }
    auto sumExpr = moduloExpr->getChild(0);
    const bool isBinaryAdd =
        isScalarFunction(sumExpr, function::AddFunction::name) && sumExpr->getNumChildren() == 2;
    if (!isBinaryAdd) {
        return std::nullopt;
    }
    return std::make_pair(sumExpr->getChild(0), sumExpr->getChild(1));
}

// Tries to map a very specific logical shape onto the PackedFilteredCount operators:
//
//   Aggregate(one key, COUNT(*)) -> Projection(one expression) -> Filter((a + b) % 10 = 0)
//       -> [Flatten]* -> child
//
// On a match, returns a PackedFilteredCountScan whose child is a PackedFilteredCount sink built
// directly on top of the mapped `child` (the projection, filter and flattens are not mapped).
// Returns nullptr whenever any precondition below fails, so the caller can fall back to the
// generic aggregate mapping.
static std::unique_ptr<PhysicalOperator> tryMapPackedFilteredCount(PlanMapper& mapper,
    main::ClientContext* clientContext, const LogicalAggregate& agg) {
    // Requires the feature flag, exactly one key, exactly one aggregate and no dependent keys.
    if (!clientContext->getClientConfig()->enablePackedPathExtend || agg.getKeys().size() != 1 ||
        agg.getAggregates().size() != 1 || agg.getDependentKeys().size() != 0) {
        return nullptr;
    }
    const auto aggregates = agg.getAggregates();
    const auto& aggregate = aggregates[0];
    // The single aggregate must take no input and be the COUNT(*) function.
    if (aggregate->getNumChildren() != 0) {
        return nullptr;
    }
    auto aggregateExpr = aggregate->constPtrCast<AggregateFunctionExpression>();
    if (aggregateExpr->getFunction().name != function::CountStarFunction::name) {
        return nullptr;
    }
    // The aggregate's child must be a projection of exactly one expression.
    const auto* projection = agg.getChild(0).get();
    if (projection->getOperatorType() != LogicalOperatorType::PROJECTION) {
        return nullptr;
    }
    const auto& logicalProjection = projection->constCast<LogicalProjection>();
    if (logicalProjection.getExpressionsToProject().size() != 1) {
        return nullptr;
    }
    // The projection's child must be a filter whose predicate is `(a + b) % 10 = 0`.
    const auto* filter = projection->getChild(0).get();
    if (filter->getOperatorType() != LogicalOperatorType::FILTER) {
        return nullptr;
    }
    const auto& logicalFilter = filter->constCast<LogicalFilter>();
    auto predicateInputs = tryGetModuloSumPredicateInputs(logicalFilter.getPredicate());
    if (!predicateInputs.has_value()) {
        return nullptr;
    }
    // Skip any flatten operators under the filter; all positions below are resolved against the
    // schema of the operator found beneath them.
    const auto* packedChild = unwrapFlattens(filter->getChild(0).get());
    auto* packedChildSchema = packedChild->getSchema();
    auto analyzer = GroupDependencyAnalyzer(true, *packedChildSchema);
    analyzer.visit(logicalFilter.getPredicate());
    // The predicate must depend on exactly two groups, and both must be unFlat.
    const auto dependentGroups = analyzer.getDependentGroups();
    if (dependentGroups.size() != 2) {
        return nullptr;
    }
    for (auto groupPos : dependentGroups) {
        if (packedChildSchema->getGroup(groupPos)->isFlat()) {
            return nullptr;
        }
    }
    // Both the group key and the count result must be INT64.
    const auto keys = agg.getKeys();
    const auto& key = keys[0];
    if (key->getDataType().getLogicalTypeID() != LogicalTypeID::INT64 ||
        aggregate->getDataType().getLogicalTypeID() != LogicalTypeID::INT64) {
        return nullptr;
    }
    // Multiplicity chunks: every in-scope unFlat group that is neither the key's group nor one
    // of the two groups the predicate depends on.
    const auto keyGroupPos = packedChildSchema->getGroupPos(*key);
    std::vector<data_chunk_pos_t> multiplicityChunks;
    for (auto groupPos : packedChildSchema->getGroupsPosInScope()) {
        if (groupPos == keyGroupPos || dependentGroups.contains(groupPos) ||
            packedChildSchema->getGroup(groupPos)->isFlat()) {
            continue;
        }
        multiplicityChunks.push_back(groupPos);
    }
    // NOTE(review): if getDependentGroups() returns an unordered container, which group ends up
    // at index 0 vs. 1 is unspecified and need not line up with predicateInputs->first/second —
    // confirm PackedFilteredCountInfo does not rely on that correspondence.
    std::vector<data_chunk_pos_t> dependentGroupsVector{dependentGroups.begin(),
        dependentGroups.end()};
    auto sharedState = std::make_shared<PackedFilteredCountSharedState>();
    // Key and predicate inputs are located in the packed child's schema; the key output and the
    // count output are located in the aggregate's own (output) schema.
    auto info = PackedFilteredCountInfo{DataPos{packedChildSchema->getExpressionPos(*key)},
        DataPos{agg.getSchema()->getExpressionPos(*key)},
        DataPos{agg.getSchema()->getExpressionPos(*aggregate)},
        DataPos{packedChildSchema->getExpressionPos(*predicateInputs->first)},
        DataPos{packedChildSchema->getExpressionPos(*predicateInputs->second)},
        dependentGroupsVector[0], dependentGroupsVector[1], std::move(multiplicityChunks)};
    // NOTE(review): mapOperator() and getOperatorID() are arguments of the same call, so C++
    // leaves their relative evaluation order unspecified — confirm operator-ID assignment does
    // not depend on the child having been mapped first.
    auto sink = std::make_unique<PackedFilteredCount>(sharedState, info,
        mapper.mapOperator(packedChild), mapper.getOperatorID(),
        std::make_unique<PackedFilteredCountPrintInfo>(logicalFilter.getPredicate(),
            agg.getKeys()));
    sink->setDescriptor(std::make_unique<ResultSetDescriptor>(packedChildSchema));
    // The scan shares state with the sink and takes the sink as its child.
    auto scan = std::make_unique<PackedFilteredCountScan>(sharedState, info.groupKeyOutputPos,
        info.countOutputPos, std::move(sink), mapper.getOperatorID(),
        std::make_unique<PackedFilteredCountPrintInfo>(logicalFilter.getPredicate(),
            agg.getKeys()));
    return scan;
}

std::unique_ptr<PhysicalOperator> PlanMapper::mapAggregate(const LogicalOperator* logicalOperator) {
    auto& agg = logicalOperator->constCast<LogicalAggregate>();
    if (auto packedFilteredCount = tryMapPackedFilteredCount(*this, clientContext, agg)) {
        return packedFilteredCount;
    }
    auto aggregates = agg.getAggregates();
    auto outSchema = agg.getSchema();
    auto child = agg.getChild(0).get();
    auto inSchema = child->getSchema();
    auto prevOperator = mapOperator(child);
    if (agg.hasKeys()) {
        return createHashAggregate(agg.getKeys(), agg.getDependentKeys(), aggregates, inSchema,
            outSchema, std::move(prevOperator));
    }
    auto aggFunctions = getAggFunctions(aggregates);
    auto aggOutputPos = getDataPos(aggregates, *outSchema);
    auto aggregateInputInfos = getAggregateInputInfos(agg.getAllKeys(), aggregates, *inSchema);
    auto sharedState =
        make_shared<SimpleAggregateSharedState>(clientContext, aggFunctions, aggregateInputInfos);
    auto printInfo = std::make_unique<SimpleAggregatePrintInfo>(aggregates);
    auto aggregate = make_unique<SimpleAggregate>(sharedState, std::move(aggFunctions),
        copyVector(aggregateInputInfos), std::move(prevOperator), getOperatorID(),
        printInfo->copy());
    aggregate->setDescriptor(std::make_unique<ResultSetDescriptor>(inSchema));
    auto finalizer = std::make_unique<SimpleAggregateFinalize>(sharedState,
        std::move(aggregateInputInfos), getOperatorID(), printInfo->copy());
    finalizer->addChild(std::move(aggregate));
    aggFunctions = getAggFunctions(aggregates);
    auto scan = std::make_unique<SimpleAggregateScan>(sharedState,
        AggregateScanInfo{std::move(aggOutputPos), getMoveAggResultToVectorFuncs(aggFunctions)},
        getOperatorID(), printInfo->copy());
    scan->addChild(std::move(finalizer));
    return scan;
}

// Lays out the hash aggregate's factorized table. Every column is flat and in group 0, in this
// order: flat keys, unFlat keys, payloads, one aggregate-state column per aggregate function,
// and finally one column sized to hold a hash_t.
static FactorizedTableSchema getFactorizedTableSchema(const expression_vector& flatKeys,
    const expression_vector& unFlatKeys, const expression_vector& payloads,
    const std::vector<AggregateFunction>& aggregateFunctions) {
    auto columnIsUnFlat = false;
    auto columnGroupID = 0u;
    auto schema = FactorizedTableSchema();
    // Keys and payloads are sized by the row-layout size of their data type.
    for (const auto* expressions : {&flatKeys, &unFlatKeys, &payloads}) {
        for (const auto& expression : *expressions) {
            auto numBytes = LogicalTypeUtils::getRowLayoutSize(expression->dataType);
            schema.appendColumn(ColumnSchema(columnIsUnFlat, columnGroupID, numBytes));
        }
    }
    // Aggregate states are sized by the function itself.
    for (const auto& function : aggregateFunctions) {
        schema.appendColumn(
            ColumnSchema(columnIsUnFlat, columnGroupID, function.getAggregateStateSize()));
    }
    // Trailing hash column.
    schema.appendColumn(ColumnSchema(columnIsUnFlat, columnGroupID, sizeof(hash_t)));
    return schema;
}

std::unique_ptr<PhysicalOperator> PlanMapper::createDistinctHashAggregate(
    const expression_vector& keys, const expression_vector& payloads, Schema* inSchema,
    Schema* outSchema, std::unique_ptr<PhysicalOperator> prevOperator) {
    return createHashAggregate(keys, payloads, expression_vector{} /* aggregates */, inSchema,
        outSchema, std::move(prevOperator));
}

// Payloads are also group by keys except that they are functional dependent on keys so we don't
// need to hash or compare payloads.
std::unique_ptr<PhysicalOperator> PlanMapper::createHashAggregate(const expression_vector& keys,
    const expression_vector& payloads, const expression_vector& aggregates, Schema* inSchema,
    Schema* outSchema, std::unique_ptr<PhysicalOperator> prevOperator) {
    // Create hash aggregate
    auto aggFunctions = getAggFunctions(aggregates);
    expression_vector allKeys;
    allKeys.insert(allKeys.end(), keys.begin(), keys.end());
    allKeys.insert(allKeys.end(), payloads.begin(), payloads.end());
    auto aggregateInputInfos = getAggregateInputInfos(allKeys, aggregates, *inSchema);
    auto flatKeys = getKeyExpressions(keys, *inSchema, true /* isFlat */);
    auto unFlatKeys = getKeyExpressions(keys, *inSchema, false /* isFlat */);
    std::vector<LogicalType> keyTypes, payloadTypes;
    for (auto& key : flatKeys) {
        keyTypes.push_back(key->getDataType().copy());
    }
    for (auto& key : unFlatKeys) {
        keyTypes.push_back(key->getDataType().copy());
    }
    for (auto& payload : payloads) {
        payloadTypes.push_back(payload->getDataType().copy());
    }
    auto tableSchema = getFactorizedTableSchema(flatKeys, unFlatKeys, payloads, aggFunctions);
    HashAggregateInfo aggregateInfo{getDataPos(flatKeys, *inSchema),
        getDataPos(unFlatKeys, *inSchema), getDataPos(payloads, *inSchema), std::move(tableSchema)};

    auto sharedState =
        std::make_shared<HashAggregateSharedState>(clientContext, std::move(aggregateInfo),
            aggFunctions, aggregateInputInfos, std::move(keyTypes), std::move(payloadTypes));
    auto printInfo = std::make_unique<HashAggregatePrintInfo>(allKeys, aggregates);
    auto aggregate = make_unique<HashAggregate>(sharedState, std::move(aggFunctions),
        std::move(aggregateInputInfos), std::move(prevOperator), getOperatorID(),
        printInfo->copy());
    aggregate->setDescriptor(std::make_unique<ResultSetDescriptor>(inSchema));
    // Create AggScan.
    expression_vector outputExpressions;
    outputExpressions.insert(outputExpressions.end(), flatKeys.begin(), flatKeys.end());
    outputExpressions.insert(outputExpressions.end(), unFlatKeys.begin(), unFlatKeys.end());
    outputExpressions.insert(outputExpressions.end(), payloads.begin(), payloads.end());
    auto aggOutputPos = getDataPos(aggregates, *outSchema);
    auto finalizer =
        std::make_unique<HashAggregateFinalize>(sharedState, getOperatorID(), printInfo->copy());
    finalizer->addChild(std::move(aggregate));
    aggFunctions = getAggFunctions(aggregates);
    auto scan =
        std::make_unique<HashAggregateScan>(sharedState, getDataPos(outputExpressions, *outSchema),
            AggregateScanInfo{std::move(aggOutputPos), getMoveAggResultToVectorFuncs(aggFunctions)},
            getOperatorID(), printInfo->copy());
    scan->addChild(std::move(finalizer));
    return scan;
}

} // namespace processor
} // namespace lbug