#include "processor/operator/aggregate/hash_aggregate.h"
#include <memory>
#include "binder/expression/expression_util.h"
#include "common/assert.h"
#include "common/types/types.h"
#include "main/client_context.h"
#include "processor/execution_context.h"
#include "processor/operator/aggregate/aggregate_hash_table.h"
#include "processor/operator/aggregate/aggregate_input.h"
#include "processor/operator/aggregate/base_aggregate.h"
#include "processor/result/factorized_table_schema.h"
#include "storage/buffer_manager/memory_manager.h"
using namespace lbug::common;
using namespace lbug::function;
using namespace lbug::storage;
namespace lbug {
namespace processor {
std::string HashAggregatePrintInfo::toString() const {
std::string result = "";
result += "Group By: ";
result += binder::ExpressionUtil::toString(keys);
if (!aggregates.empty()) {
result += ", Aggregates: ";
result += binder::ExpressionUtil::toString(aggregates);
}
if (limitNum != UINT64_MAX) {
result += ", Distinct Limit: " + std::to_string(limitNum);
}
return result;
}
HashAggregateInfo::HashAggregateInfo(std::vector<DataPos> flatKeysPos,
std::vector<DataPos> unFlatKeysPos, std::vector<DataPos> dependentKeysPos,
FactorizedTableSchema tableSchema)
: flatKeysPos{std::move(flatKeysPos)}, unFlatKeysPos{std::move(unFlatKeysPos)},
dependentKeysPos{std::move(dependentKeysPos)}, tableSchema{std::move(tableSchema)} {}
HashAggregateInfo::HashAggregateInfo(const HashAggregateInfo& other)
: flatKeysPos{other.flatKeysPos}, unFlatKeysPos{other.unFlatKeysPos},
dependentKeysPos{other.dependentKeysPos}, tableSchema{other.tableSchema.copy()} {}
HashAggregateSharedState::HashAggregateSharedState(main::ClientContext* context,
HashAggregateInfo hashAggInfo,
const std::vector<function::AggregateFunction>& aggregateFunctions,
std::span<AggregateInfo> aggregateInfos, std::vector<LogicalType> keyTypes,
std::vector<LogicalType> payloadTypes)
: BaseAggregateSharedState{aggregateFunctions, getNumPartitionsForParallelism(context)},
aggInfo{std::move(hashAggInfo)}, limitNumber{common::INVALID_LIMIT},
memoryManager{MemoryManager::Get(*context)},
globalPartitions{getNumPartitionsForParallelism(context)} {
std::vector<LogicalType> distinctAggregateKeyTypes;
for (auto& aggInfo : aggregateInfos) {
distinctAggregateKeyTypes.push_back(aggInfo.distinctAggKeyType.copy());
}
for (size_t i = 0; i < this->aggInfo.tableSchema.getNumColumns() - 1; i++) {
this->aggInfo.tableSchema.setMayContainsNullsToTrue(i);
}
auto& partition = globalPartitions[0];
partition.queue = std::make_unique<HashTableQueue>(MemoryManager::Get(*context),
this->aggInfo.tableSchema.copy());
partition.hashTable = std::make_unique<AggregateHashTable>(*MemoryManager::Get(*context),
std::move(keyTypes), std::move(payloadTypes), aggregateFunctions, distinctAggregateKeyTypes,
0, this->aggInfo.tableSchema.copy());
for (size_t functionIdx = 0; functionIdx < aggregateFunctions.size(); functionIdx++) {
auto& function = aggregateFunctions[functionIdx];
if (function.isFunctionDistinct()) {
auto distinctTableSchema = FactorizedTableSchema();
for (size_t i = 0;
i < this->aggInfo.flatKeysPos.size() + this->aggInfo.unFlatKeysPos.size(); i++) {
distinctTableSchema.appendColumn(this->aggInfo.tableSchema.getColumn(i)->copy());
distinctTableSchema.setMayContainsNullsToTrue(i);
}
distinctTableSchema.appendColumn(ColumnSchema(false , 0 ,
LogicalTypeUtils::getRowLayoutSize(
aggregateInfos[functionIdx].distinctAggKeyType)));
distinctTableSchema.setMayContainsNullsToTrue(distinctTableSchema.getNumColumns() - 1);
distinctTableSchema.appendColumn(
ColumnSchema(false , 0 , sizeof(hash_t)));
partition.distinctTableQueues.emplace_back(std::make_unique<HashTableQueue>(
MemoryManager::Get(*context), std::move(distinctTableSchema)));
} else {
partition.distinctTableQueues.emplace_back();
}
}
for (size_t i = 1; i < globalPartitions.size(); i++) {
globalPartitions[i].queue = std::make_unique<HashTableQueue>(MemoryManager::Get(*context),
this->aggInfo.tableSchema.copy());
globalPartitions[i].distinctTableQueues.resize(partition.distinctTableQueues.size());
std::transform(partition.distinctTableQueues.begin(), partition.distinctTableQueues.end(),
globalPartitions[i].distinctTableQueues.begin(), [&](auto& q) {
if (q.get() != nullptr) {
return q->copy();
} else {
return std::unique_ptr<HashTableQueue>();
}
});
}
}
std::pair<uint64_t, uint64_t> HashAggregateSharedState::getNextRangeToRead() {
std::unique_lock lck{mtx};
auto startOffset = currentOffset.load();
auto numTuples = getNumTuples();
if (startOffset >= numTuples) {
return std::make_pair(startOffset, startOffset);
}
auto [table, tableStartOffset] = getPartitionForOffset(startOffset);
auto range = std::min(std::min(DEFAULT_VECTOR_CAPACITY, numTuples - startOffset),
table->getNumTuples() + tableStartOffset - startOffset);
currentOffset += range;
return std::make_pair(startOffset, startOffset + range);
}
uint64_t HashAggregateSharedState::getNumTuples() const {
uint64_t numTuples = 0;
for (auto& partition : globalPartitions) {
numTuples += partition.hashTable->getNumEntries();
}
return numTuples;
}
void HashAggregateSharedState::finalizePartitions() {
BaseAggregateSharedState::finalizePartitions(globalPartitions, [&](auto& partition) {
if (!partition.hashTable) {
partition.hashTable = std::make_unique<AggregateHashTable>(
globalPartitions[0].hashTable->createEmptyCopy());
}
for (size_t i = 0; i < partition.distinctTableQueues.size(); i++) {
if (partition.distinctTableQueues[i]) {
partition.distinctTableQueues[i]->mergeInto(
*partition.hashTable->getDistinctHashTable(i));
}
}
partition.queue->mergeInto(*partition.hashTable);
partition.hashTable->mergeDistinctAggregateInfo();
partition.hashTable->finalizeAggregateStates();
});
}
std::tuple<const FactorizedTable*, offset_t> HashAggregateSharedState::getPartitionForOffset(
offset_t offset) const {
auto factorizedTableStartOffset = 0;
auto partitionIdx = 0;
const auto* table = globalPartitions[partitionIdx].hashTable->getFactorizedTable();
while (factorizedTableStartOffset + table->getNumTuples() <= offset) {
factorizedTableStartOffset += table->getNumTuples();
table = globalPartitions[++partitionIdx].hashTable->getFactorizedTable();
}
return std::make_tuple(table, factorizedTableStartOffset);
}
void HashAggregateSharedState::scan(std::span<uint8_t*> entries,
std::vector<common::ValueVector*>& keyVectors, offset_t startOffset, offset_t numTuplesToScan,
std::vector<uint32_t>& columnIndices) {
auto [table, tableStartOffset] = getPartitionForOffset(startOffset);
DASSERT(startOffset - tableStartOffset + numTuplesToScan <= table->getNumTuples());
for (size_t pos = 0; pos < numTuplesToScan; pos++) {
auto posInTable = startOffset + pos - tableStartOffset;
entries[pos] = table->getTuple(posInTable);
}
table->lookup(keyVectors, columnIndices, entries.data(), 0, numTuplesToScan);
DASSERT(true);
}
void HashAggregateSharedState::assertFinalized() const {
RUNTIME_CHECK(for (const auto& partition
: globalPartitions) {
DASSERT(partition.finalized);
DASSERT(partition.queue->empty());
});
}
void HashAggregateLocalState::init(HashAggregateSharedState* sharedState, ResultSet& resultSet,
main::ClientContext* context, std::vector<function::AggregateFunction>& aggregateFunctions,
std::vector<common::LogicalType> distinctKeyTypes) {
auto& info = sharedState->getAggregateInfo();
std::vector<LogicalType> keyDataTypes;
for (auto& pos : info.flatKeysPos) {
auto vector = resultSet.getValueVector(pos).get();
keyVectors.push_back(vector);
keyDataTypes.push_back(vector->dataType.copy());
}
for (auto& pos : info.unFlatKeysPos) {
auto vector = resultSet.getValueVector(pos).get();
keyVectors.push_back(vector);
keyDataTypes.push_back(vector->dataType.copy());
leadingState = vector->state.get();
}
if (leadingState == nullptr) {
leadingState = keyVectors.front()->state.get();
}
std::vector<LogicalType> payloadDataTypes;
for (auto& pos : info.dependentKeysPos) {
auto vector = resultSet.getValueVector(pos).get();
dependentKeyVectors.push_back(vector);
payloadDataTypes.push_back(vector->dataType.copy());
}
aggregateHashTable = std::make_unique<PartitioningAggregateHashTable>(sharedState,
*MemoryManager::Get(*context), std::move(keyDataTypes), std::move(payloadDataTypes),
aggregateFunctions, std::move(distinctKeyTypes), info.tableSchema.copy());
}
uint64_t HashAggregateLocalState::append(const std::vector<AggregateInput>& aggregateInputs,
uint64_t multiplicity) const {
return aggregateHashTable->append(keyVectors, dependentKeyVectors, leadingState,
aggregateInputs, multiplicity);
}
void HashAggregate::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) {
BaseAggregate::initLocalStateInternal(resultSet, context);
std::vector<LogicalType> distinctAggKeyTypes;
for (auto& info : aggInfos) {
distinctAggKeyTypes.push_back(info.distinctAggKeyType.copy());
}
localState.init(common::dynamic_cast_checked<HashAggregateSharedState*>(sharedState.get()),
*resultSet, context->clientContext, aggregateFunctions, std::move(distinctAggKeyTypes));
}
void HashAggregate::executeInternal(ExecutionContext* context) {
while (children[0]->getNextTuple(context)) {
const auto numAppendedFlatTuples = localState.append(aggInputs, resultSet->multiplicity);
metrics->numOutputTuple.increase(numAppendedFlatTuples);
if (localState.aggregateHashTable->getNumEntries() >=
getSharedStateReference().getLimitNumber()) {
break;
}
}
localState.aggregateHashTable->mergeIfFull(0 , true );
}
} }