#include "binder/expression/aggregate_function_expression.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression/scalar_function_expression.h"
#include "common/copy_constructors.h"
#include "common/types/types.h"
#include "function/aggregate/count_star.h"
#include "function/arithmetic/vector_arithmetic_functions.h"
#include "function/comparison/vector_comparison_functions.h"
#include "main/client_context.h"
#include "planner/operator/factorization/flatten_resolver.h"
#include "planner/operator/logical_aggregate.h"
#include "planner/operator/logical_filter.h"
#include "planner/operator/logical_flatten.h"
#include "planner/operator/logical_projection.h"
#include "processor/operator/aggregate/hash_aggregate.h"
#include "processor/operator/aggregate/hash_aggregate_scan.h"
#include "processor/operator/aggregate/packed_filtered_count.h"
#include "processor/operator/aggregate/simple_aggregate.h"
#include "processor/operator/aggregate/simple_aggregate_scan.h"
#include "processor/plan_mapper.h"
#include "processor/result/result_set_descriptor.h"
using namespace lbug::binder;
using namespace lbug::common;
using namespace lbug::function;
using namespace lbug::planner;
namespace lbug {
namespace processor {
static std::vector<AggregateInfo> getAggregateInputInfos(const expression_vector& keys,
const expression_vector& aggregates, const Schema& schema) {
std::unordered_set<f_group_pos> groupByGroupPosSet;
for (auto& expression : keys) {
groupByGroupPosSet.insert(schema.getGroupPos(*expression));
}
std::unordered_set<f_group_pos> unFlatAggregateGroupPosSet;
for (auto groupPos : schema.getGroupsPosInScope()) {
if (groupByGroupPosSet.contains(groupPos)) {
continue;
}
if (schema.getGroup(groupPos)->isFlat()) {
continue;
}
unFlatAggregateGroupPosSet.insert(groupPos);
}
std::vector<AggregateInfo> result;
for (auto& expression : aggregates) {
auto aggregateVectorPos = DataPos::getInvalidPos();
if (expression->getNumChildren() != 0) { auto child = expression->getChild(0);
aggregateVectorPos = DataPos{schema.getExpressionPos(*child)};
}
std::vector<data_chunk_pos_t> multiplicityChunksPos;
for (auto& groupPos : unFlatAggregateGroupPosSet) {
if (groupPos != aggregateVectorPos.dataChunkPos) {
multiplicityChunksPos.push_back(groupPos);
}
}
auto aggExpr = expression->constPtrCast<AggregateFunctionExpression>();
auto distinctAggKeyType = aggExpr->isDistinct() ?
expression->getChild(0)->getDataType().copy() :
LogicalType::ANY();
result.emplace_back(aggregateVectorPos, std::move(multiplicityChunksPos),
std::move(distinctAggKeyType));
}
return result;
}
// Returns the subset of `expressions` (in their original order) whose factorization group in
// `schema` has the requested flatness.
static expression_vector getKeyExpressions(const expression_vector& expressions,
    const Schema& schema, bool isFlat) {
    expression_vector selected;
    for (const auto& candidate : expressions) {
        const auto groupPos = schema.getGroupPos(*candidate);
        if (schema.getGroup(groupPos)->isFlat() != isFlat) {
            continue;
        }
        selected.push_back(candidate);
    }
    return selected;
}
static std::vector<AggregateFunction> getAggFunctions(const expression_vector& aggregates) {
std::vector<AggregateFunction> aggregateFunctions;
for (auto& expression : aggregates) {
auto aggExpr = expression->constPtrCast<AggregateFunctionExpression>();
aggregateFunctions.push_back(aggExpr->getFunction().copy());
}
return aggregateFunctions;
}
// Writes an aggregate result into `vector` at `pos` for states that carry a null flag.
// The state is read as an AggregateStateWithNull: its isNull flag is copied to the vector and
// the value is written only when the state is not null.
static void writeAggResultWithNullToVector(ValueVector& vector, uint64_t pos,
    AggregateState* aggregateState) {
    const auto stateIsNull = aggregateState->constCast<AggregateStateWithNull>().isNull;
    vector.setNull(pos, stateIsNull);
    if (stateIsNull) {
        return;
    }
    aggregateState->writeToVector(&vector, pos);
}
// Writes an aggregate result into `vector` at `pos` without consulting any null flag on the
// state: the slot is always marked non-null before the value is written.
static void writeAggResultWithoutNullToVector(ValueVector& vector, uint64_t pos,
    AggregateState* aggregateState) {
    const bool markNull = false;
    vector.setNull(pos, markNull);
    aggregateState->writeToVector(&vector, pos);
}
static std::vector<move_agg_result_to_vector_func> getMoveAggResultToVectorFuncs(
std::vector<AggregateFunction>& aggregateFunctions) {
std::vector<move_agg_result_to_vector_func> moveAggResultToVectorFuncs;
for (auto& aggregateFunction : aggregateFunctions) {
if (aggregateFunction.needToHandleNulls) {
moveAggResultToVectorFuncs.push_back(writeAggResultWithoutNullToVector);
} else {
moveAggResultToVectorFuncs.push_back(writeAggResultWithNullToVector);
}
}
return moveAggResultToVectorFuncs;
}
// Walks down through consecutive FLATTEN operators (following child 0) and returns the first
// operator that is not a FLATTEN. Returns `op` itself when it is not a FLATTEN.
static const LogicalOperator* unwrapFlattens(const LogicalOperator* op) {
    const LogicalOperator* current = op;
    while (true) {
        if (current->getOperatorType() != LogicalOperatorType::FLATTEN) {
            return current;
        }
        current = current->getChild(0).get();
    }
}
static bool isScalarFunction(const std::shared_ptr<Expression>& expression,
const std::string& name) {
auto scalarFunction = dynamic_cast<const ScalarFunctionExpression*>(expression.get());
if (scalarFunction == nullptr) {
return false;
}
return scalarFunction->getFunction().name == name;
}
// Returns true when `expression` is a LITERAL of logical type INT64 whose value equals `value`.
static bool isInt64Literal(const std::shared_ptr<Expression>& expression, int64_t value) {
    if (expression->expressionType != ExpressionType::LITERAL) {
        return false;
    }
    if (expression->getDataType().getLogicalTypeID() != LogicalTypeID::INT64) {
        return false;
    }
    const auto& literal = expression->constCast<LiteralExpression>();
    return literal.getValue().getValue<int64_t>() == value;
}
// Recognises predicates of the exact shape `(a + b) % 10 = 0` (the INT64 literal 0 may be on
// either side of the equality) and returns the two operands of the addition as {a, b}.
// Returns std::nullopt for any other predicate.
static std::optional<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>>
tryGetModuloSumPredicateInputs(const std::shared_ptr<Expression>& predicate) {
    const bool isBinaryEquals = isScalarFunction(predicate, function::EqualsFunction::name) &&
                                predicate->getNumChildren() == 2;
    if (!isBinaryEquals) {
        return std::nullopt;
    }
    auto lhs = predicate->getChild(0);
    auto rhs = predicate->getChild(1);
    // If the left operand is the literal 0, the modulo is expected on the right; otherwise the
    // modulo is expected on the left and the literal on the right.
    const bool zeroOnLeft = isInt64Literal(lhs, 0);
    const auto& zero = zeroOnLeft ? lhs : rhs;
    const auto& modulo = zeroOnLeft ? rhs : lhs;
    const bool isModuloTenEqualsZero =
        isInt64Literal(zero, 0) && isScalarFunction(modulo, function::ModuloFunction::name) &&
        modulo->getNumChildren() == 2 && isInt64Literal(modulo->getChild(1), 10);
    if (!isModuloTenEqualsZero) {
        return std::nullopt;
    }
    auto sum = modulo->getChild(0);
    const bool isBinaryAdd =
        isScalarFunction(sum, function::AddFunction::name) && sum->getNumChildren() == 2;
    if (!isBinaryAdd) {
        return std::nullopt;
    }
    return std::make_pair(sum->getChild(0), sum->getChild(1));
}
// Tries to map `agg` onto the specialised PackedFilteredCountScan -> PackedFilteredCount
// pipeline. Returns nullptr, without mapping anything, unless all of the following hold:
//  - the client config has enablePackedPathExtend set,
//  - the aggregate has exactly one key, one aggregate expression and no dependent keys,
//  - the aggregate expression has no children and is bound to CountStarFunction,
//  - the aggregate's child is a PROJECTION with a single projected expression,
//  - the projection's child is a FILTER whose predicate has the shape `(a + b) % 10 = 0`,
//  - in the schema of the operator below the filter (after skipping FLATTENs) the predicate
//    depends on exactly two groups and both are unflat,
//  - the key and the aggregate result are both INT64.
// On success the returned scan owns the sink, which owns the mapped child plan.
static std::unique_ptr<PhysicalOperator> tryMapPackedFilteredCount(PlanMapper& mapper,
    main::ClientContext* clientContext, const LogicalAggregate& agg) {
    if (!clientContext->getClientConfig()->enablePackedPathExtend || agg.getKeys().size() != 1 ||
        agg.getAggregates().size() != 1 || agg.getDependentKeys().size() != 0) {
        return nullptr;
    }
    // The single aggregate must be COUNT(*): no children and the count-star function name.
    const auto aggregates = agg.getAggregates();
    const auto& aggregate = aggregates[0];
    if (aggregate->getNumChildren() != 0) {
        return nullptr;
    }
    auto aggregateExpr = aggregate->constPtrCast<AggregateFunctionExpression>();
    if (aggregateExpr->getFunction().name != function::CountStarFunction::name) {
        return nullptr;
    }
    // Expected logical shape below the aggregate: PROJECTION(1 expression) -> FILTER -> ...
    const auto* projection = agg.getChild(0).get();
    if (projection->getOperatorType() != LogicalOperatorType::PROJECTION) {
        return nullptr;
    }
    const auto& logicalProjection = projection->constCast<LogicalProjection>();
    if (logicalProjection.getExpressionsToProject().size() != 1) {
        return nullptr;
    }
    const auto* filter = projection->getChild(0).get();
    if (filter->getOperatorType() != LogicalOperatorType::FILTER) {
        return nullptr;
    }
    const auto& logicalFilter = filter->constCast<LogicalFilter>();
    auto predicateInputs = tryGetModuloSumPredicateInputs(logicalFilter.getPredicate());
    if (!predicateInputs.has_value()) {
        return nullptr;
    }
    // The packed operator reads directly from the operator below any FLATTENs under the filter.
    const auto* packedChild = unwrapFlattens(filter->getChild(0).get());
    auto* packedChildSchema = packedChild->getSchema();
    auto analyzer = GroupDependencyAnalyzer(true, *packedChildSchema);
    analyzer.visit(logicalFilter.getPredicate());
    const auto dependentGroups = analyzer.getDependentGroups();
    if (dependentGroups.size() != 2) {
        return nullptr;
    }
    for (auto groupPos : dependentGroups) {
        if (packedChildSchema->getGroup(groupPos)->isFlat()) {
            return nullptr;
        }
    }
    const auto keys = agg.getKeys();
    const auto& key = keys[0];
    if (key->getDataType().getLogicalTypeID() != LogicalTypeID::INT64 ||
        aggregate->getDataType().getLogicalTypeID() != LogicalTypeID::INT64) {
        return nullptr;
    }
    // NOTE(review): the key and both predicate inputs are resolved against packedChildSchema
    // below without checking that they are in its scope -- confirm the planner guarantees this
    // for every plan that reaches this point.
    const auto keyGroupPos = packedChildSchema->getGroupPos(*key);
    // Remaining unflat groups (not the key group, not a predicate group) are passed as
    // multiplicity chunks.
    std::vector<data_chunk_pos_t> multiplicityChunks;
    for (auto groupPos : packedChildSchema->getGroupsPosInScope()) {
        if (groupPos == keyGroupPos || dependentGroups.contains(groupPos) ||
            packedChildSchema->getGroup(groupPos)->isFlat()) {
            continue;
        }
        multiplicityChunks.push_back(groupPos);
    }
    // NOTE(review): the two group positions are taken in the iteration order of
    // dependentGroups -- confirm PackedFilteredCountInfo does not need them aligned with the
    // first/second predicate input.
    std::vector<data_chunk_pos_t> dependentGroupsVector{dependentGroups.begin(),
        dependentGroups.end()};
    auto sharedState = std::make_shared<PackedFilteredCountSharedState>();
    auto info = PackedFilteredCountInfo{DataPos{packedChildSchema->getExpressionPos(*key)},
        DataPos{agg.getSchema()->getExpressionPos(*key)},
        DataPos{agg.getSchema()->getExpressionPos(*aggregate)},
        DataPos{packedChildSchema->getExpressionPos(*predicateInputs->first)},
        DataPos{packedChildSchema->getExpressionPos(*predicateInputs->second)},
        dependentGroupsVector[0], dependentGroupsVector[1], std::move(multiplicityChunks)};
    // Map the child before requesting the sink's operator ID. Both calls used to sit in the
    // same argument list, where their relative order is unspecified, so the ID assignment
    // depended on the compiler. Mapping the child first matches mapAggregate.
    auto prevOperator = mapper.mapOperator(packedChild);
    auto sink = std::make_unique<PackedFilteredCount>(sharedState, info, std::move(prevOperator),
        mapper.getOperatorID(),
        std::make_unique<PackedFilteredCountPrintInfo>(logicalFilter.getPredicate(),
            agg.getKeys()));
    sink->setDescriptor(std::make_unique<ResultSetDescriptor>(packedChildSchema));
    auto scan = std::make_unique<PackedFilteredCountScan>(sharedState, info.groupKeyOutputPos,
        info.countOutputPos, std::move(sink), mapper.getOperatorID(),
        std::make_unique<PackedFilteredCountPrintInfo>(logicalFilter.getPredicate(),
            agg.getKeys()));
    return scan;
}
std::unique_ptr<PhysicalOperator> PlanMapper::mapAggregate(const LogicalOperator* logicalOperator) {
auto& agg = logicalOperator->constCast<LogicalAggregate>();
if (auto packedFilteredCount = tryMapPackedFilteredCount(*this, clientContext, agg)) {
return packedFilteredCount;
}
auto aggregates = agg.getAggregates();
auto outSchema = agg.getSchema();
auto child = agg.getChild(0).get();
auto inSchema = child->getSchema();
auto prevOperator = mapOperator(child);
if (agg.hasKeys()) {
return createHashAggregate(agg.getKeys(), agg.getDependentKeys(), aggregates, inSchema,
outSchema, std::move(prevOperator));
}
auto aggFunctions = getAggFunctions(aggregates);
auto aggOutputPos = getDataPos(aggregates, *outSchema);
auto aggregateInputInfos = getAggregateInputInfos(agg.getAllKeys(), aggregates, *inSchema);
auto sharedState =
make_shared<SimpleAggregateSharedState>(clientContext, aggFunctions, aggregateInputInfos);
auto printInfo = std::make_unique<SimpleAggregatePrintInfo>(aggregates);
auto aggregate = make_unique<SimpleAggregate>(sharedState, std::move(aggFunctions),
copyVector(aggregateInputInfos), std::move(prevOperator), getOperatorID(),
printInfo->copy());
aggregate->setDescriptor(std::make_unique<ResultSetDescriptor>(inSchema));
auto finalizer = std::make_unique<SimpleAggregateFinalize>(sharedState,
std::move(aggregateInputInfos), getOperatorID(), printInfo->copy());
finalizer->addChild(std::move(aggregate));
aggFunctions = getAggFunctions(aggregates);
auto scan = std::make_unique<SimpleAggregateScan>(sharedState,
AggregateScanInfo{std::move(aggOutputPos), getMoveAggResultToVectorFuncs(aggFunctions)},
getOperatorID(), printInfo->copy());
scan->addChild(std::move(finalizer));
return scan;
}
// Builds the row layout of the aggregate hash table. Columns are appended in this order:
// flat keys, unflat keys, payloads, one aggregate-state column per aggregate function, and a
// trailing column of sizeof(hash_t). Every column is appended as flat in group 0.
static FactorizedTableSchema getFactorizedTableSchema(const expression_vector& flatKeys,
    const expression_vector& unFlatKeys, const expression_vector& payloads,
    const std::vector<AggregateFunction>& aggregateFunctions) {
    const auto isUnFlat = false;
    const auto groupID = 0u;
    auto tableSchema = FactorizedTableSchema();
    // Appends a single column of the given byte width.
    const auto appendColumn = [&](auto numBytes) {
        tableSchema.appendColumn(ColumnSchema(isUnFlat, groupID, numBytes));
    };
    // Appends one column per expression, sized by the row-layout size of its data type.
    const auto appendColumnsFor = [&](const expression_vector& expressions) {
        for (const auto& expression : expressions) {
            appendColumn(LogicalTypeUtils::getRowLayoutSize(expression->dataType));
        }
    };
    appendColumnsFor(flatKeys);
    appendColumnsFor(unFlatKeys);
    appendColumnsFor(payloads);
    for (const auto& function : aggregateFunctions) {
        appendColumn(function.getAggregateStateSize());
    }
    appendColumn(sizeof(hash_t));
    return tableSchema;
}
std::unique_ptr<PhysicalOperator> PlanMapper::createDistinctHashAggregate(
const expression_vector& keys, const expression_vector& payloads, Schema* inSchema,
Schema* outSchema, std::unique_ptr<PhysicalOperator> prevOperator) {
return createHashAggregate(keys, payloads, expression_vector{} , inSchema,
outSchema, std::move(prevOperator));
}
std::unique_ptr<PhysicalOperator> PlanMapper::createHashAggregate(const expression_vector& keys,
const expression_vector& payloads, const expression_vector& aggregates, Schema* inSchema,
Schema* outSchema, std::unique_ptr<PhysicalOperator> prevOperator) {
auto aggFunctions = getAggFunctions(aggregates);
expression_vector allKeys;
allKeys.insert(allKeys.end(), keys.begin(), keys.end());
allKeys.insert(allKeys.end(), payloads.begin(), payloads.end());
auto aggregateInputInfos = getAggregateInputInfos(allKeys, aggregates, *inSchema);
auto flatKeys = getKeyExpressions(keys, *inSchema, true );
auto unFlatKeys = getKeyExpressions(keys, *inSchema, false );
std::vector<LogicalType> keyTypes, payloadTypes;
for (auto& key : flatKeys) {
keyTypes.push_back(key->getDataType().copy());
}
for (auto& key : unFlatKeys) {
keyTypes.push_back(key->getDataType().copy());
}
for (auto& payload : payloads) {
payloadTypes.push_back(payload->getDataType().copy());
}
auto tableSchema = getFactorizedTableSchema(flatKeys, unFlatKeys, payloads, aggFunctions);
HashAggregateInfo aggregateInfo{getDataPos(flatKeys, *inSchema),
getDataPos(unFlatKeys, *inSchema), getDataPos(payloads, *inSchema), std::move(tableSchema)};
auto sharedState =
std::make_shared<HashAggregateSharedState>(clientContext, std::move(aggregateInfo),
aggFunctions, aggregateInputInfos, std::move(keyTypes), std::move(payloadTypes));
auto printInfo = std::make_unique<HashAggregatePrintInfo>(allKeys, aggregates);
auto aggregate = make_unique<HashAggregate>(sharedState, std::move(aggFunctions),
std::move(aggregateInputInfos), std::move(prevOperator), getOperatorID(),
printInfo->copy());
aggregate->setDescriptor(std::make_unique<ResultSetDescriptor>(inSchema));
expression_vector outputExpressions;
outputExpressions.insert(outputExpressions.end(), flatKeys.begin(), flatKeys.end());
outputExpressions.insert(outputExpressions.end(), unFlatKeys.begin(), unFlatKeys.end());
outputExpressions.insert(outputExpressions.end(), payloads.begin(), payloads.end());
auto aggOutputPos = getDataPos(aggregates, *outSchema);
auto finalizer =
std::make_unique<HashAggregateFinalize>(sharedState, getOperatorID(), printInfo->copy());
finalizer->addChild(std::move(aggregate));
aggFunctions = getAggFunctions(aggregates);
auto scan =
std::make_unique<HashAggregateScan>(sharedState, getDataPos(outputExpressions, *outSchema),
AggregateScanInfo{std::move(aggOutputPos), getMoveAggResultToVectorFuncs(aggFunctions)},
getOperatorID(), printInfo->copy());
scan->addChild(std::move(finalizer));
return scan;
}
} }