#include <iostream>
#include "queryscript/include/duckdb-extra.hpp"
#include "queryscript/src/runtime/duckdb/engine.rs.h"
std::unique_ptr<ArrowArrayStreamWrapper> new_array_stream_wrapper(uintptr_t data, duckdb::ArrowStreamParameters ¶meters)
{
auto ret = duckdb::make_unique<ArrowArrayStreamWrapper>();
rust_build_array_stream(
(uint32_t *)data,
parameters.projected_columns.columns,
(uint32_t *)&ret->arrow_array_stream);
return ret;
}
/// Returns the address of new_array_stream_wrapper disguised as an opaque
/// `uint32_t *` handle.
///
/// @return The function's address reinterpreted as `uint32_t *`; the pointer is
///         not meant to be dereferenced as data.
uint32_t *get_create_stream_fn()
{
    const auto factory = &new_array_stream_wrapper;
    return reinterpret_cast<uint32_t *>(factory);
}
/// Wraps a raw pointer in a heap-allocated DuckDB POINTER value.
///
/// @param value  Pointer whose address is stored in the resulting value.
/// @return A `duckdb::Value` created with `new`, returned through the `Value *`
///         type. Nothing in this function frees it, so the receiver is
///         responsible for its lifetime.
Value *duckdb_create_pointer(uint32_t *value)
{
    const auto address = reinterpret_cast<uintptr_t>(value);
    const duckdb::Value pointer_value = duckdb::Value::POINTER(address);

    // Copy the temporary onto the heap so it outlives this call.
    auto *boxed = new duckdb::Value(pointer_value);
    return reinterpret_cast<Value *>(boxed);
}
// Forward declaration: lets code below name BuiltinFunctions by reference
// (ArrowTableFunction::RegisterFunction takes a `BuiltinFunctions &`) without
// needing the full class definition.
// NOTE(review): this declares ::BuiltinFunctions at global scope; whether the
// use inside namespace duckdb resolves to this or to a duckdb::BuiltinFunctions
// from the included headers cannot be determined from here — confirm.
class BuiltinFunctions;
// Declarations placed in namespace duckdb that the Arrow scan registration
// (init_arrow_scan) depends on. Nothing in this block is defined out of line
// here: the static member functions of ArrowTableFunction are declared only.
// NOTE(review): these look like local copies of declarations from DuckDB's own
// Arrow scan support; if so, they must match the linked DuckDB version exactly
// (member order, types, signatures) — confirm before changing anything.
namespace duckdb
{
// One-byte tag; stored paired with an idx_t in ArrowConvertData::variable_sz_type.
enum class ArrowVariableSizeType : uint8_t { FIXED_SIZE = 0, NORMAL = 1, SUPER_SIZE = 2 };
// One-byte time-unit tag; stored in ArrowConvertData::date_time_precision.
enum class ArrowDateTimeType : uint8_t {
MILLISECONDS = 0,
MICROSECONDS = 1,
NANOSECONDS = 2,
SECONDS = 3,
DAYS = 4,
MONTHS = 5
};
// Conversion metadata record. GetArrowLogicalType below receives a map from
// idx_t to these (presumably keyed by column index — TODO confirm).
struct ArrowConvertData {
// Initializes dictionary_type from the given type; the vectors start empty.
ArrowConvertData(LogicalType type) : dictionary_type(type) {};
// Default-constructs every member.
ArrowConvertData() {};
LogicalType dictionary_type;
// Sequence of (size tag, idx_t) pairs.
vector<pair<ArrowVariableSizeType, idx_t>> variable_sz_type;
// Sequence of time-unit tags.
vector<ArrowDateTimeType> date_time_precision;
};
// Static entry points of the Arrow table function. init_arrow_scan passes
// ArrowScanFunction, ArrowScanBind, ArrowScanInitGlobal, ArrowScanInitLocal,
// ArrowScanCardinality and ArrowGetBatchIndex to a TableFunction; the other
// members are not referenced in the visible part of this file.
struct ArrowTableFunction
{
public:
static void RegisterFunction(BuiltinFunctions &set);
// Bind callback: receives the bind input and fills return_types / names
// through the out-parameters.
static unique_ptr<FunctionData> ArrowScanBind(ClientContext &context, TableFunctionBindInput &input,
vector<LogicalType> &return_types, vector<string> &names);
// Creates the global scan state.
static unique_ptr<GlobalTableFunctionState> ArrowScanInitGlobal(ClientContext &context,
TableFunctionInitInput &input);
// Creates a local scan state; also receives the global state.
static unique_ptr<LocalTableFunctionState> ArrowScanInitLocal(ExecutionContext &context,
TableFunctionInitInput &input,
GlobalTableFunctionState *global_state);
// Scan callback: writes results into `output`.
static void ArrowScanFunction(ClientContext &context, TableFunctionInput &data, DataChunk &output);
static idx_t ArrowScanMaxThreads(ClientContext &context, const FunctionData *bind_data);
static idx_t ArrowGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p,
LocalTableFunctionState *local_state, GlobalTableFunctionState *global_state);
static unique_ptr<NodeStatistics> ArrowScanCardinality(ClientContext &context, const FunctionData *bind_data);
static double ArrowProgress(ClientContext &context, const FunctionData *bind_data,
const GlobalTableFunctionState *global_state);
static void RenameArrowColumns(vector<string> &names);
// Returns a LogicalType for `schema`; takes the conversion-metadata map by
// non-const reference (presumably to record entries for col_idx — TODO confirm).
static LogicalType GetArrowLogicalType(ArrowSchema &schema,
std::unordered_map<idx_t, unique_ptr<ArrowConvertData>> &arrow_convert_data,
idx_t col_idx);
};
}
void init_arrow_scan(uint32_t *connection_ptr)
{
using namespace duckdb;
TableFunction arrow("arrow_scan_qs", {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER},
ArrowTableFunction::ArrowScanFunction, ArrowTableFunction::ArrowScanBind,
ArrowTableFunction::ArrowScanInitGlobal, ArrowTableFunction::ArrowScanInitLocal);
arrow.cardinality = ArrowTableFunction::ArrowScanCardinality;
arrow.get_batch_index = ArrowTableFunction::ArrowGetBatchIndex;
arrow.projection_pushdown = true;
arrow.filter_pushdown = false; arrow.filter_prune = true;
auto tf_info = CreateTableFunctionInfo(move(arrow));
tf_info.on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT;
auto con = (duckdb::Connection *)connection_ptr;
con->context->RunFunctionInTransaction([&]() {
auto &catalog = duckdb::Catalog::GetSystemCatalog(*con->context);
catalog.CreateTableFunction(*con->context, &tf_info);
});
}