#include "expression_evaluator/lambda_evaluator.h"
#include "binder/expression/lambda_expression.h"
#include "common/exception/runtime.h"
#include "expression_evaluator/expression_evaluator_visitor.h"
#include "expression_evaluator/list_slice_info.h"
#include "function/list/vector_list_functions.h"
#include "main/client_context.h"
#include "parser/expression/parsed_lambda_expression.h"
#include "storage/buffer_manager/memory_manager.h"
#include <format>
using namespace lbug::common;
using namespace lbug::processor;
using namespace lbug::main;
using namespace lbug::storage;
namespace lbug {
namespace evaluator {
void ListLambdaEvaluator::init(const ResultSet& resultSet, ClientContext* clientContext) {
for (auto& child : children) {
child->init(resultSet, clientContext);
}
DASSERT(children.size() == 1);
auto listInputVector = children[0]->resultVector.get();
auto collector = LambdaParamEvaluatorCollector();
collector.visit(lambdaRootEvaluator.get());
auto evaluators = collector.getEvaluators();
auto lambdaVarState = std::make_shared<DataChunkState>();
memoryManager = MemoryManager::Get(*clientContext);
for (auto& evaluator : evaluators) {
evaluator->resultVector =
listLambdaType != ListLambdaType::LIST_REDUCE ?
ListVector::getSharedDataVector(listInputVector) :
std::make_shared<ValueVector>(
ListType::getChildType(listInputVector->dataType).copy(), memoryManager);
evaluator->resultVector->state = lambdaVarState;
lambdaParamEvaluators.push_back(evaluator->ptrCast<LambdaParamEvaluator>());
}
lambdaRootEvaluator->init(resultSet, clientContext);
resolveResultVector(resultSet, memoryManager);
params.push_back(children[0]->resultVector);
params.push_back(lambdaRootEvaluator->resultVector);
auto paramIndices = getParamIndices();
bindData = ListLambdaBindData{lambdaParamEvaluators, paramIndices, lambdaRootEvaluator.get()};
}
void ListLambdaEvaluator::evaluateInternal() {
auto* inputVector = params[0].get();
if (resultVector->dataType.getPhysicalType() == PhysicalTypeID::LIST) {
ListVector::resizeDataVector(resultVector.get(),
ListVector::getDataVectorSize(inputVector));
}
ListSliceInfo sliceInfo{inputVector};
bindData.sliceInfo = &sliceInfo;
auto selVectors = SelectionVector::fromValueVectors(params);
do {
sliceInfo.nextSlice();
execFunc(params, selVectors, *resultVector, resultVector->getSelVectorPtr(), &bindData);
} while (!sliceInfo.done());
}
void ListLambdaEvaluator::evaluate() {
DASSERT(children.size() == 1);
children[0]->evaluate();
evaluateInternal();
}
bool ListLambdaEvaluator::selectInternal(SelectionVector& selVector) {
DASSERT(children.size() == 1);
children[0]->evaluate();
evaluateInternal();
return updateSelectedPos(selVector);
}
void ListLambdaEvaluator::resolveResultVector(const ResultSet&, MemoryManager* memoryManager) {
resultVector = std::make_shared<ValueVector>(expression->getDataType().copy(), memoryManager);
resultVector->state = children[0]->resultVector->state;
isResultFlat_ = children[0]->isResultFlat();
}
std::vector<idx_t> ListLambdaEvaluator::getParamIndices() {
const auto& paramNames = getExpression()
->getChild(1)
->constCast<binder::LambdaExpression>()
.getParsedLambdaExpr()
->constCast<parser::ParsedLambdaExpression>()
.getVarNames();
std::vector<idx_t> index(lambdaParamEvaluators.size());
for (idx_t i = 0; i < lambdaParamEvaluators.size(); i++) {
auto paramName = lambdaParamEvaluators[i]->getVarName();
auto it = std::find(paramNames.begin(), paramNames.end(), paramName);
if (it != paramNames.end()) {
index[i] = it - paramNames.begin();
} else {
throw RuntimeException(std::format("Lambda paramName {} cannot found.", paramName));
}
}
return index;
}
ListLambdaType ListLambdaEvaluator::checkListLambdaTypeWithFunctionName(std::string functionName) {
if (0 == functionName.compare(function::ListTransformFunction::name)) {
return ListLambdaType::LIST_TRANSFORM;
} else if (0 == functionName.compare(function::ListFilterFunction::name)) {
return ListLambdaType::LIST_FILTER;
} else if (0 == functionName.compare(function::ListReduceFunction::name)) {
return ListLambdaType::LIST_REDUCE;
} else {
return ListLambdaType::DEFAULT;
}
}
} }