#include "binder/expression/expression_util.h"
#include "common/exception/binder.h"
#include "common/type_utils.h"
#include "function/list/functions/list_position_function.h"
#include "function/list/vector_list_functions.h"
#include "function/scalar_function.h"
#include <format>
using namespace lbug::common;
using namespace lbug::binder;
namespace lbug {
namespace function {
struct ListContains {
template<typename T>
static void operation(common::list_entry_t& list, T& element, uint8_t& result,
common::ValueVector& listVector, common::ValueVector& elementVector,
common::ValueVector& resultVector) {
int64_t pos = 0;
ListPosition::operation(list, element, pos, listVector, elementVector, resultVector);
result = (pos != 0);
}
};
static std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& input) {
auto scalarFunction = input.definition->ptrCast<ScalarFunction>();
std::vector<LogicalType> paramTypes;
LogicalType childType;
auto listExpr = input.arguments[0];
auto elementExpr = input.arguments[1];
if (ExpressionUtil::isEmptyList(*listExpr)) {
childType = elementExpr->getDataType().copy();
} else {
auto& listChildType = ListType::getChildType(listExpr->getDataType());
auto& elementType = elementExpr->getDataType();
if (!LogicalTypeUtils::tryGetMaxLogicalType(listChildType, elementType, childType)) {
throw BinderException(std::format("Cannot compare {} and {} in list_contains function.",
listChildType.toString(), elementType.toString()));
}
}
if (childType.getLogicalTypeID() == LogicalTypeID::ANY) {
childType = LogicalType::STRING();
}
auto listType = LogicalType::LIST(childType.copy());
paramTypes.push_back(listType.copy());
paramTypes.push_back(childType.copy());
TypeUtils::visit(childType.getPhysicalType(), [&scalarFunction]<typename T>(T) {
scalarFunction->execFunc =
ScalarFunction::BinaryExecListStructFunction<list_entry_t, T, uint8_t, ListContains>;
});
return std::make_unique<FunctionBindData>(std::move(paramTypes), LogicalType::BOOL());
}
function_set ListContainsFunction::getFunctionSet() {
function_set result;
auto function = std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::BOOL);
function->bindFunc = bindFunc;
result.push_back(std::move(function));
return result;
}
} }