// lbug 0.17.0
//
// An in-process property graph database management system built for query speed and scalability.
#include "function/aggregate/percentile_disc.h"

#include <algorithm>
#include <cmath>
#include <vector>

#include "binder/expression/literal_expression.h"
#include "common/exception/binder.h"
#include "common/type_utils.h"

using namespace lbug::binder;
using namespace lbug::common;

namespace lbug {
namespace function {

// Header of one node in the singly linked list of values collected by the aggregate.
// getElementValue() locates the value bytes directly behind this header, and updateSingleValue()
// allocates header and value together in one chunk.
struct PercentileDiscElement {
    // Following node in the list; nullptr marks the last node.
    PercentileDiscElement* next{nullptr};
};

// Aggregate state of percentileDisc. It keeps the collected values as a linked list
// (head/tail/count), the requested percentile, and a pointer to the value that finalize()
// picked out of the list.
struct PercentileDiscState : public AggregateStateWithNull {
    explicit PercentileDiscState(double percentile = 0) : percentile(percentile) {}

    // Size of this state object in bytes.
    uint32_t getStateSize() const override { return sizeof(PercentileDiscState); }

    // Copies the selected value into slot `pos` of `outputVector`.
    // selectedValue is only assigned in finalize(), hence the assertion.
    void writeToVector(common::ValueVector* outputVector, uint64_t pos) override {
        DASSERT(selectedValue != nullptr);
        outputVector->copyFromRowData(pos, selectedValue);
    }

    // First and last node of the value list; both nullptr while the list is empty.
    PercentileDiscElement* head{nullptr};
    PercentileDiscElement* tail{nullptr};
    // Number of nodes currently linked into the list.
    uint64_t count{0};
    // Requested percentile; bindPercentile() restricts it to the range 0.0 to 1.0.
    double percentile;
    // Row data of the chosen value; stays nullptr until finalize() has selected one.
    uint8_t* selectedValue{nullptr};
};

// Returns the address of the value bytes stored directly after the node header.
static uint8_t* getElementValue(PercentileDiscElement* element) {
    // Stepping the typed pointer by one skips exactly sizeof(PercentileDiscElement) bytes.
    auto* pastHeader = element + 1;
    return reinterpret_cast<uint8_t*>(pastHeader);
}

static bool valueLess(const uint8_t* left, const uint8_t* right, LogicalTypeID typeID) {
    switch (typeID) {
    case LogicalTypeID::INT8:
        return *reinterpret_cast<const int8_t*>(left) < *reinterpret_cast<const int8_t*>(right);
    case LogicalTypeID::INT16:
        return *reinterpret_cast<const int16_t*>(left) < *reinterpret_cast<const int16_t*>(right);
    case LogicalTypeID::INT32:
        return *reinterpret_cast<const int32_t*>(left) < *reinterpret_cast<const int32_t*>(right);
    case LogicalTypeID::INT64:
    case LogicalTypeID::SERIAL:
        return *reinterpret_cast<const int64_t*>(left) < *reinterpret_cast<const int64_t*>(right);
    case LogicalTypeID::UINT8:
        return *reinterpret_cast<const uint8_t*>(left) < *reinterpret_cast<const uint8_t*>(right);
    case LogicalTypeID::UINT16:
        return *reinterpret_cast<const uint16_t*>(left) < *reinterpret_cast<const uint16_t*>(right);
    case LogicalTypeID::UINT32:
        return *reinterpret_cast<const uint32_t*>(left) < *reinterpret_cast<const uint32_t*>(right);
    case LogicalTypeID::UINT64:
        return *reinterpret_cast<const uint64_t*>(left) < *reinterpret_cast<const uint64_t*>(right);
    case LogicalTypeID::FLOAT:
        return *reinterpret_cast<const float*>(left) < *reinterpret_cast<const float*>(right);
    case LogicalTypeID::DOUBLE:
        return *reinterpret_cast<const double*>(left) < *reinterpret_cast<const double*>(right);
    case LogicalTypeID::INT128:
        return *reinterpret_cast<const int128_t*>(left) < *reinterpret_cast<const int128_t*>(right);
    case LogicalTypeID::UINT128:
        return *reinterpret_cast<const uint128_t*>(left) <
               *reinterpret_cast<const uint128_t*>(right);
    default:
        UNREACHABLE_CODE;
    }
}

// Creates a fresh aggregate state using the constructor's default percentile (0).
// bindFunc() installs its own initializer that carries the bound percentile instead.
static std::unique_ptr<AggregateState> initialize() {
    std::unique_ptr<AggregateState> freshState = std::make_unique<PercentileDiscState>();
    return freshState;
}

// Appends the value at `pos` of `input` to the state's list, once per unit of `multiplicity`.
//
// Each copy lives in one chunk taken from `overflowBuffer`: a PercentileDiscElement header
// followed by the value in row layout. A multiplicity of 0 leaves the state untouched, so it
// also stays null in that case.
static void updateSingleValue(PercentileDiscState* state, common::ValueVector* input, uint32_t pos,
    uint64_t multiplicity, common::InMemOverflowBuffer* overflowBuffer) {
    const auto valueSize = LogicalTypeUtils::getRowLayoutSize(input->dataType);
    // The counter has the same 64-bit width as `multiplicity`. A 32-bit counter would wrap
    // around before reaching a bound above UINT32_MAX and the loop would never terminate.
    for (uint64_t i = 0; i < multiplicity; ++i) {
        auto* element = reinterpret_cast<PercentileDiscElement*>(
            overflowBuffer->allocateSpace(sizeof(PercentileDiscElement) + valueSize));
        // The chunk is used as raw storage, so the link is set explicitly.
        element->next = nullptr;
        input->copyToRowData(pos, getElementValue(element), overflowBuffer);
        if (state->tail) {
            state->tail->next = element;
        } else {
            // First element of an empty list.
            state->head = element;
        }
        state->tail = element;
        state->count++;
        state->isNull = false;
    }
}

// Feeds every non-null position of an unflat `input` vector into the aggregate state.
// Each visited position is appended `multiplicity` times via updateSingleValue().
static void updateAll(uint8_t* state_, common::ValueVector* input, uint64_t multiplicity,
    common::InMemOverflowBuffer* overflowBuffer) {
    DASSERT(!input->state->isFlat());
    auto* percentileState = reinterpret_cast<PercentileDiscState*>(state_);
    auto appendAt = [&](auto pos) {
        updateSingleValue(percentileState, input, pos, multiplicity, overflowBuffer);
    };
    input->forEachNonNull(appendAt);
}

static void updatePos(uint8_t* state_, common::ValueVector* input, uint64_t multiplicity,
    uint32_t pos, common::InMemOverflowBuffer* overflowBuffer) {
    updateSingleValue(reinterpret_cast<PercentileDiscState*>(state_), input, pos, multiplicity,
        overflowBuffer);
}

// Moves all values collected in `otherState_` to the end of the list in `state_`.
// The nodes themselves are relinked, not copied; afterwards the other state is reset to an
// empty, null state so it no longer refers to the moved nodes. A null other state is a no-op.
static void combine(uint8_t* state_, uint8_t* otherState_,
    common::InMemOverflowBuffer* /*overflowBuffer*/) {
    auto* source = reinterpret_cast<PercentileDiscState*>(otherState_);
    if (!source->isNull) {
        auto* target = reinterpret_cast<PercentileDiscState*>(state_);

        // Link the source list behind the current tail, or make it the whole list when the
        // target has no elements yet.
        if (target->tail == nullptr) {
            target->head = source->head;
        } else {
            target->tail->next = source->head;
        }
        target->tail = source->tail;
        target->count += source->count;
        target->isNull = false;

        // Detach the source from the nodes it handed over.
        source->isNull = true;
        source->count = 0;
        source->tail = nullptr;
        source->head = nullptr;
    }
}

static void finalize(uint8_t* state_, LogicalTypeID typeID) {
    auto* state = reinterpret_cast<PercentileDiscState*>(state_);
    if (state->isNull) {
        return;
    }
    std::vector<uint8_t*> values;
    values.reserve(state->count);
    for (auto* element = state->head; element != nullptr; element = element->next) {
        values.push_back(getElementValue(element));
    }
    std::sort(values.begin(), values.end(),
        [typeID](auto left, auto right) { return valueLess(left, right, typeID); });
    auto rawIndex = static_cast<int64_t>(std::ceil(state->percentile * values.size())) - 1;
    auto index = std::clamp<int64_t>(rawIndex, 0, static_cast<int64_t>(values.size()) - 1);
    state->selectedValue = values[index];
}

// Extracts and validates the percentile argument of percentileDisc at bind time.
//
// Returns the percentile as a double in the closed range [0.0, 1.0].
// Throws BinderException when the call does not have exactly two arguments, when the second
// argument is not a literal expression, or when the percentile is NaN or outside that range.
static double bindPercentile(const ScalarBindFuncInput& input) {
    if (input.arguments.size() != 2) {
        throw BinderException("percentileDisc requires exactly two arguments.");
    }
    auto literalExpr = dynamic_cast<LiteralExpression*>(input.arguments[1].get());
    if (literalExpr == nullptr) {
        throw BinderException("Second parameter of percentileDisc must be a literal.");
    }
    const auto percentile = literalExpr->getValue().getValue<double>();
    // Phrased as "not inside the range" on purpose: every comparison with NaN is false, so a
    // check of the form `p < 0 || p > 1` would let NaN through, and finalize() would then cast
    // ceil(NaN) to an integer index.
    const bool inRange = percentile >= 0 && percentile <= 1;
    if (!inRange) {
        throw BinderException("percentileDisc percentile must be between 0.0 and 1.0.");
    }
    return percentile;
}

static std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& input) {
    auto percentile = bindPercentile(input);
    auto typeID = input.arguments[0]->dataType.getLogicalTypeID();
    auto* aggregateFunction = input.definition->ptrCast<AggregateFunction>();
    aggregateFunction->initializeFunc = [percentile]() {
        return std::make_unique<PercentileDiscState>(percentile);
    };
    aggregateFunction->finalizeFunc = [typeID](auto state) { finalize(state, typeID); };
    aggregateFunction->initialNullAggregateState =
        aggregateFunction->createInitialNullAggregateState();
    return FunctionBindData::getSimpleBindData(input.arguments, input.arguments[0]->dataType);
}

// Builds the overload set of percentileDisc: for every numerical logical type id there is a
// DISTINCT and a non-DISTINCT variant taking (value, DOUBLE) and returning the value type.
// The finalize callback registered here is an empty placeholder; bindFunc() assigns the real
// one at bind time.
function_set AggregatePercentileDiscFunction::getFunctionSet() {
    function_set functions;
    for (auto valueTypeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) {
        for (bool distinct : {true, false}) {
            auto overload = std::make_unique<AggregateFunction>(name,
                std::vector<LogicalTypeID>{valueTypeID, LogicalTypeID::DOUBLE}, valueTypeID,
                initialize, updateAll, updatePos, combine, [](auto) {}, distinct, bindFunc);
            functions.push_back(std::move(overload));
        }
    }
    return functions;
}

} // namespace function
} // namespace lbug