#include "function/aggregate/percentile_disc.h"
#include <algorithm>
#include <cmath>
#include <vector>
#include "binder/expression/literal_expression.h"
#include "common/exception/binder.h"
#include "common/type_utils.h"
using namespace lbug::binder;
using namespace lbug::common;
namespace lbug {
namespace function {
// Header of one node in the singly linked list of collected input values.
// Only the link is stored here; the value bytes are placed directly behind the header
// (see getElementValue and the allocation in updateSingleValue).
struct PercentileDiscElement {
    // Next node in insertion order, or nullptr for the last node.
    PercentileDiscElement* next{nullptr};
};
// Aggregate state of percentileDisc: a linked list of every collected value plus the
// percentile to pick from it once all input has been seen.
struct PercentileDiscState : public AggregateStateWithNull {
    // First and last node of the list of collected values (nullptr while empty).
    PercentileDiscElement* head = nullptr;
    PercentileDiscElement* tail = nullptr;
    // Number of nodes reachable from head.
    uint64_t count = 0;
    // Requested percentile; bindPercentile only admits values in [0, 1].
    double percentile;
    // Points at the value bytes chosen by finalize; nullptr until then.
    uint8_t* selectedValue = nullptr;

    explicit PercentileDiscState(double percentile = 0) : percentile(percentile) {}

    uint32_t getStateSize() const override {
        return sizeof(PercentileDiscState);
    }

    // Copies the value chosen by finalize into position `pos` of `outputVector`.
    void writeToVector(common::ValueVector* outputVector, uint64_t pos) override {
        DASSERT(selectedValue != nullptr);
        outputVector->copyFromRowData(pos, selectedValue);
    }
};
// Returns the address of the value bytes stored directly behind the node header.
static uint8_t* getElementValue(PercentileDiscElement* element) {
    // Stepping one whole header forward lands on the first payload byte.
    PercentileDiscElement* const pastHeader = element + 1;
    return reinterpret_cast<uint8_t*>(pastHeader);
}
// "Less than" for floating-point values that remains a strict weak ordering when NaN is
// present: NaN compares after every non-NaN value and all NaNs are equivalent to each other.
// A plain `<` is false in both directions for NaN, which breaks the ordering requirements of
// the standard sorting/selection algorithms.
template<typename T>
static bool floatingPointLess(T left, T right) {
    if (std::isnan(left)) {
        return false;
    }
    if (std::isnan(right)) {
        return true;
    }
    return left < right;
}

// Compares two stored values of the given numeric type.
//
// @param left   pointer to the bytes of the left-hand value
// @param right  pointer to the bytes of the right-hand value
// @param typeID logical type that tells how to interpret both byte sequences
// @return true when the left value orders strictly before the right value; for FLOAT and
//         DOUBLE, NaN orders after all other values.
//
// Any type ID not listed below is treated as unreachable.
static bool valueLess(const uint8_t* left, const uint8_t* right, LogicalTypeID typeID) {
    switch (typeID) {
    case LogicalTypeID::INT8:
        return *reinterpret_cast<const int8_t*>(left) < *reinterpret_cast<const int8_t*>(right);
    case LogicalTypeID::INT16:
        return *reinterpret_cast<const int16_t*>(left) < *reinterpret_cast<const int16_t*>(right);
    case LogicalTypeID::INT32:
        return *reinterpret_cast<const int32_t*>(left) < *reinterpret_cast<const int32_t*>(right);
    case LogicalTypeID::INT64:
    case LogicalTypeID::SERIAL:
        return *reinterpret_cast<const int64_t*>(left) < *reinterpret_cast<const int64_t*>(right);
    case LogicalTypeID::UINT8:
        return *reinterpret_cast<const uint8_t*>(left) < *reinterpret_cast<const uint8_t*>(right);
    case LogicalTypeID::UINT16:
        return *reinterpret_cast<const uint16_t*>(left) <
               *reinterpret_cast<const uint16_t*>(right);
    case LogicalTypeID::UINT32:
        return *reinterpret_cast<const uint32_t*>(left) <
               *reinterpret_cast<const uint32_t*>(right);
    case LogicalTypeID::UINT64:
        return *reinterpret_cast<const uint64_t*>(left) <
               *reinterpret_cast<const uint64_t*>(right);
    case LogicalTypeID::FLOAT:
        return floatingPointLess(*reinterpret_cast<const float*>(left),
            *reinterpret_cast<const float*>(right));
    case LogicalTypeID::DOUBLE:
        return floatingPointLess(*reinterpret_cast<const double*>(left),
            *reinterpret_cast<const double*>(right));
    case LogicalTypeID::INT128:
        return *reinterpret_cast<const int128_t*>(left) <
               *reinterpret_cast<const int128_t*>(right);
    case LogicalTypeID::UINT128:
        return *reinterpret_cast<const uint128_t*>(left) <
               *reinterpret_cast<const uint128_t*>(right);
    default:
        UNREACHABLE_CODE;
    }
}
// Creates a state with the default percentile (0); bindFunc installs a replacement
// initializer that carries the percentile taken from the query.
static std::unique_ptr<AggregateState> initialize() {
    std::unique_ptr<AggregateState> freshState = std::make_unique<PercentileDiscState>();
    return freshState;
}
// Appends `multiplicity` copies of the input value at `pos` to the state's value list.
//
// @param state          aggregate state that receives the new nodes
// @param input          vector holding the value to collect
// @param pos            position of the value inside `input`
// @param multiplicity   number of copies to append; 0 leaves the state untouched
// @param overflowBuffer buffer from which each node (header + value bytes) is allocated
static void updateSingleValue(PercentileDiscState* state, common::ValueVector* input, uint32_t pos,
    uint64_t multiplicity, common::InMemOverflowBuffer* overflowBuffer) {
    const auto valueSize = LogicalTypeUtils::getRowLayoutSize(input->dataType);
    // The counter is 64-bit like `multiplicity`: a 32-bit counter would wrap around and never
    // reach a multiplicity above UINT32_MAX.
    for (uint64_t copy = 0; copy < multiplicity; ++copy) {
        // One allocation holds the node header immediately followed by the value bytes.
        auto* element = reinterpret_cast<PercentileDiscElement*>(
            overflowBuffer->allocateSpace(sizeof(PercentileDiscElement) + valueSize));
        // The memory is raw, so the link has to be initialised by hand.
        element->next = nullptr;
        input->copyToRowData(pos, getElementValue(element), overflowBuffer);
        if (state->tail) {
            state->tail->next = element;
        } else {
            state->head = element;
        }
        state->tail = element;
        state->count++;
        state->isNull = false;
    }
}
static void updateAll(uint8_t* state_, common::ValueVector* input, uint64_t multiplicity,
common::InMemOverflowBuffer* overflowBuffer) {
DASSERT(!input->state->isFlat());
auto* state = reinterpret_cast<PercentileDiscState*>(state_);
input->forEachNonNull(
[&](auto pos) { updateSingleValue(state, input, pos, multiplicity, overflowBuffer); });
}
static void updatePos(uint8_t* state_, common::ValueVector* input, uint64_t multiplicity,
uint32_t pos, common::InMemOverflowBuffer* overflowBuffer) {
updateSingleValue(reinterpret_cast<PercentileDiscState*>(state_), input, pos, multiplicity,
overflowBuffer);
}
static void combine(uint8_t* state_, uint8_t* otherState_,
common::InMemOverflowBuffer* ) {
auto* otherState = reinterpret_cast<PercentileDiscState*>(otherState_);
if (otherState->isNull) {
return;
}
auto* state = reinterpret_cast<PercentileDiscState*>(state_);
if (state->tail) {
state->tail->next = otherState->head;
} else {
state->head = otherState->head;
}
state->tail = otherState->tail;
state->count += otherState->count;
state->isNull = false;
otherState->head = nullptr;
otherState->tail = nullptr;
otherState->count = 0;
otherState->isNull = true;
}
// Picks the discrete percentile from the collected values and stores a pointer to it in
// `selectedValue`.
//
// The chosen element is the one at zero-based rank ceil(percentile * n) - 1 of the values in
// ascending order (per valueLess), clamped to [0, n - 1]. A null state is left untouched.
//
// @param state_ raw pointer to a PercentileDiscState
// @param typeID logical type of the collected values, used to compare them
static void finalize(uint8_t* state_, LogicalTypeID typeID) {
    auto* state = reinterpret_cast<PercentileDiscState*>(state_);
    if (state->isNull) {
        return;
    }
    // Gather pointers to the value bytes so they can be rearranged without moving the values.
    std::vector<uint8_t*> values;
    values.reserve(state->count);
    for (auto* element = state->head; element != nullptr; element = element->next) {
        values.push_back(getElementValue(element));
    }
    if (values.empty()) {
        // Defensive: a non-null state is expected to hold at least one value. Bail out rather
        // than calling std::clamp with an upper bound below the lower bound.
        return;
    }
    const auto rawIndex =
        static_cast<int64_t>(std::ceil(state->percentile * values.size())) - 1;
    // percentile == 0 yields rawIndex == -1; clamp so that it selects the smallest value.
    const auto index =
        std::clamp<int64_t>(rawIndex, 0, static_cast<int64_t>(values.size()) - 1);
    // Only one order statistic is needed, so a selection (linear on average) replaces a full
    // sort; it places at `target` the element a complete sort would put there.
    const auto target = values.begin() + index;
    std::nth_element(values.begin(), target, values.end(),
        [typeID](auto left, auto right) { return valueLess(left, right, typeID); });
    state->selectedValue = *target;
}
// Extracts and validates the percentile argument of a percentileDisc call.
//
// @param input bind input whose second argument must be a literal expression
// @return the percentile, guaranteed to lie in [0, 1]
// @throws BinderException when the call does not have exactly two arguments, when the second
//         argument is not a literal, or when the percentile is outside [0, 1] (NaN included).
static double bindPercentile(const ScalarBindFuncInput& input) {
    if (input.arguments.size() != 2) {
        throw BinderException("percentileDisc requires exactly two arguments.");
    }
    auto literalExpr = dynamic_cast<LiteralExpression*>(input.arguments[1].get());
    if (literalExpr == nullptr) {
        throw BinderException("Second parameter of percentileDisc must be a literal.");
    }
    auto percentile = literalExpr->getValue().getValue<double>();
    // Written as a negated "in range" test so that NaN, for which every comparison is false,
    // is rejected instead of slipping through.
    if (!(percentile >= 0 && percentile <= 1)) {
        throw BinderException("percentileDisc percentile must be between 0.0 and 1.0.");
    }
    return percentile;
}
static std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& input) {
auto percentile = bindPercentile(input);
auto typeID = input.arguments[0]->dataType.getLogicalTypeID();
auto* aggregateFunction = input.definition->ptrCast<AggregateFunction>();
aggregateFunction->initializeFunc = [percentile]() {
return std::make_unique<PercentileDiscState>(percentile);
};
aggregateFunction->finalizeFunc = [typeID](auto state) { finalize(state, typeID); };
aggregateFunction->initialNullAggregateState =
aggregateFunction->createInitialNullAggregateState();
return FunctionBindData::getSimpleBindData(input.arguments, input.arguments[0]->dataType);
}
// Builds the overload set: for every numerical value type, a DISTINCT and a non-DISTINCT
// variant taking (value, DOUBLE percentile) and returning the value type.
function_set AggregatePercentileDiscFunction::getFunctionSet() {
    function_set overloads;
    for (auto valueTypeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) {
        for (const bool isDistinct : {true, false}) {
            // The finalize slot gets a no-op placeholder; bindFunc swaps in the real one.
            auto overload = std::make_unique<AggregateFunction>(name,
                std::vector<LogicalTypeID>{valueTypeID, LogicalTypeID::DOUBLE}, valueTypeID,
                initialize, updateAll, updatePos, combine, [](auto) {}, isDistinct, bindFunc);
            overloads.push_back(std::move(overload));
        }
    }
    return overloads;
}
} }