from typing import *
import pyarrow as pa
import pyarrow.flight
import pyarrow.parquet
import inspect
import traceback
import json
from concurrent.futures import ThreadPoolExecutor
import concurrent
from decimal import Decimal
import signal
class UserDefinedFunction:
    """Base interface for all user-defined functions served over Arrow Flight."""

    _name: str
    _input_schema: pa.Schema
    _result_schema: pa.Schema
    # Optional thread pool; only populated by subclasses that want parallel rows.
    _executor: Optional[ThreadPoolExecutor] = None

    def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
        """Evaluate one input batch; the base implementation yields nothing."""
        return iter(())
class ScalarFunction(UserDefinedFunction):
    """A user-defined scalar function: exactly one output value per input row."""

    # True when the wrapped callable consumes whole columns at once.
    _batch: bool

    def __init__(self, *args, **kwargs):
        self._batch = kwargs.pop("batch", False)
        io_threads = kwargs.pop("io_threads", None) or 1
        # A thread pool only makes sense for row-at-a-time, I/O-bound functions.
        if io_threads > 1 and not self._batch:
            self._executor = ThreadPoolExecutor(max_workers=io_threads)
        super().__init__(*args, **kwargs)

    def eval(self, *args) -> Any:
        """Evaluate one row; concrete implementations override this."""
        pass

    def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
        """Convert the batch to Python values, evaluate, and yield one result batch."""
        columns = [[v.as_py() for v in array] for array in batch]
        if self._batch:
            # Column-oriented call: hand every column to the function at once.
            results = self._func(*columns)
        elif self._executor is not None:
            # Fan rows out to the thread pool; order is preserved by map().
            rows = ([col[i] for col in columns] for i in range(batch.num_rows))
            results = list(self._executor.map(lambda row: self._func(*row), rows))
        else:
            # Plain sequential row-at-a-time evaluation.
            results = [
                self.eval(*(col[i] for col in columns))
                for i in range(batch.num_rows)
            ]
        array = _to_arrow_array(results, self._result_schema.types[0])
        yield pa.RecordBatch.from_arrays([array], schema=self._result_schema)
def _to_arrow_array(column: List, type: pa.DataType) -> pa.Array:
    """Convert a list of Python values into an arrow array of the given type.

    Handles nested lists, structs, and the JSON/Decimal extension types
    recursively; everything else falls through to ``pa.array``.
    """
    if pa.types.is_list(type):
        # Flatten nested lists into a values array plus offsets and a null mask.
        offsets = [0]
        values = []
        mask = []
        for item in column:
            if item is not None:
                values.extend(item)
            # One offset per row (null rows are zero-length).
            offsets.append(len(values))
            mask.append(item is None)
        return pa.ListArray.from_arrays(
            pa.array(offsets, type=pa.int32()),
            _to_arrow_array(values, type.value_type),
            mask=pa.array(mask, type=pa.bool_()),
        )
    if pa.types.is_struct(type):
        # Convert field by field, then recombine with a row-level null mask.
        children = [
            _to_arrow_array(
                [None if row is None else row.get(field.name) for row in column],
                field.type,
            )
            for field in type
        ]
        null_mask = pa.array([row is None for row in column], type=pa.bool_())
        return pa.StructArray.from_arrays(children, fields=type, mask=null_mask)
    if type.equals(JsonType()):
        # JSON values are stored as their string encoding.
        storage = pa.array(
            [None if v is None else json.dumps(v) for v in column],
            type=pa.string(),
        )
        return pa.ExtensionArray.from_storage(JsonType(), storage)
    if type.equals(DecimalType()):
        # Decimals are stored as plain (non-scientific) decimal strings.
        storage = pa.array(
            [None if v is None else _decimal_to_str(v) for v in column],
            type=pa.string(),
        )
        return pa.ExtensionArray.from_storage(DecimalType(), storage)
    return pa.array(column, type=type)
def _decimal_to_str(v: Decimal) -> str:
if not isinstance(v, Decimal):
raise ValueError(f"Expected Decimal, got {v}")
return format(v, "f")
class TableFunction(UserDefinedFunction):
    """A user-defined table function: zero or more output rows per input row."""

    # Output rows are flushed whenever this many have accumulated.
    BATCH_SIZE = 1024

    def eval(self, *args) -> Iterator:
        """Yield output values for one input row; implementations override this."""
        yield

    def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
        """Evaluate every input row and yield (row-index, value) result batches."""

        class _Accumulator:
            """Buffers (row index, value) pairs until a batch is full."""

            schema: pa.Schema
            columns: List[List]

            def __init__(self, schema: pa.Schema):
                self.schema = schema
                self.columns = [[] for _ in schema.types]

            def len(self) -> int:
                return len(self.columns[0])

            def append(self, index: int, value: Any):
                # Column 0 is the input row index; column 1 is the output value.
                self.columns[0].append(index)
                self.columns[1].append(value)

            def build(self) -> pa.RecordBatch:
                arrays = [
                    pa.array(col, typ)
                    for col, typ in zip(self.columns, self.schema.types)
                ]
                # Reset the buffers for the next batch.
                self.columns = [[] for _ in self.schema.types]
                return pa.RecordBatch.from_arrays(arrays, schema=self.schema)

        acc = _Accumulator(self._result_schema)
        for row_index in range(batch.num_rows):
            row = tuple(column[row_index].as_py() for column in batch)
            for value in self.eval(*row):
                acc.append(row_index, value)
                if acc.len() == self.BATCH_SIZE:
                    yield acc.build()
        if acc.len() != 0:
            yield acc.build()
class UserDefinedScalarFunctionWrapper(ScalarFunction):
    """Adapts a plain Python callable into a ScalarFunction."""

    _func: Callable

    def __init__(self, func, input_types, result_type, name=None, **kwargs):
        self._func = func
        # Prefer the callable's own name; fall back to its class name.
        self._name = name or getattr(func, "__name__", func.__class__.__name__)
        arg_names = inspect.getfullargspec(func)[0]
        arg_types = [_to_data_type(t) for t in _to_list(input_types)]
        self._input_schema = pa.schema(zip(arg_names, arg_types))
        self._result_schema = pa.schema([(self._name, _to_data_type(result_type))])
        super().__init__(**kwargs)

    def __call__(self, *args):
        return self._func(*args)

    def eval(self, *args):
        return self._func(*args)
class UserDefinedTableFunctionWrapper(TableFunction):
    """Adapts a Python generator function into a TableFunction."""

    _func: Callable

    def __init__(self, func, input_types, result_types, name=None):
        self._func = func
        # Prefer the callable's own name; fall back to its class name.
        self._name = name or getattr(func, "__name__", func.__class__.__name__)
        arg_names = inspect.getfullargspec(func)[0]
        arg_types = [_to_data_type(t) for t in _to_list(input_types)]
        self._input_schema = pa.schema(zip(arg_names, arg_types))
        if isinstance(result_types, list):
            # Multiple outputs are packed into a struct of anonymous fields.
            output_type = pa.struct([("", _to_data_type(t)) for t in result_types])
        else:
            output_type = _to_data_type(result_types)
        # Column "row" carries the input row index each output row belongs to.
        self._result_schema = pa.schema(
            [("row", pa.int32()), (self._name, output_type)]
        )

    def __call__(self, *args):
        return self._func(*args)

    def eval(self, *args):
        return self._func(*args)
def _to_list(x):
if isinstance(x, list):
return x
else:
return [x]
def udf(
    input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]],
    result_type: Union[str, pa.DataType],
    name: Optional[str] = None,
    io_threads: Optional[int] = None,
    batch: bool = False,
) -> Callable:
    """Decorator that turns a Python function into a scalar UDF wrapper."""

    def decorator(f):
        return UserDefinedScalarFunctionWrapper(
            f, input_types, result_type, name, io_threads=io_threads, batch=batch
        )

    return decorator
def udtf(
    input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]],
    result_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]],
    name: Optional[str] = None,
) -> Callable:
    """Decorator that turns a Python generator function into a table UDF wrapper."""

    def decorator(f):
        return UserDefinedTableFunctionWrapper(f, input_types, result_types, name)

    return decorator
class UdfServer(pa.flight.FlightServerBase):
    """An Arrow Flight server that hosts user-defined functions.

    Typical usage::

        server = UdfServer(location="0.0.0.0:8815")
        server.add_function(my_udf)
        server.serve()
    """

    _location: str
    _functions: Dict[str, UserDefinedFunction]

    def __init__(self, location="0.0.0.0:8815", **kwargs):
        super(UdfServer, self).__init__("grpc://" + location, **kwargs)
        self._location = location
        self._functions = {}

    def get_flight_info(self, context, descriptor):
        """Return schema info for the function named by the descriptor path."""
        udf = self._functions[descriptor.path[0].decode("utf-8")]
        return self._make_flight_info(udf)

    def _make_flight_info(self, udf: UserDefinedFunction) -> pa.flight.FlightInfo:
        # Input fields followed by result fields; `total_records` carries the
        # number of input arguments so clients can split the schema back apart.
        full_schema = pa.schema(list(udf._input_schema) + list(udf._result_schema))
        return pa.flight.FlightInfo(
            schema=full_schema,
            descriptor=pa.flight.FlightDescriptor.for_path(udf._name),
            endpoints=[],
            total_records=len(udf._input_schema),
            total_bytes=0,
        )

    def list_flights(self, context, criteria):
        """List flight info for every registered function."""
        return [self._make_flight_info(udf) for udf in self._functions.values()]

    def add_function(self, udf: UserDefinedFunction):
        """Register a function.

        Raises ValueError if a function with the same name already exists.
        """
        name = udf._name
        if name in self._functions:
            raise ValueError("Function already exists: " + name)
        print(f"added function: {name}")
        self._functions[name] = udf

    def do_exchange(self, context, descriptor, reader, writer):
        """Stream input batches through the named UDF and write results back."""
        udf = self._functions[descriptor.path[0].decode("utf-8")]
        writer.begin(udf._result_schema)
        try:
            for batch in reader:
                for output_batch in udf.eval_batch(batch.data):
                    writer.write_batch(output_batch)
        except Exception:
            # FIX: was `print(traceback.print_exc())`, which printed a spurious
            # `None` — print_exc() already writes the traceback and returns None.
            traceback.print_exc()
            # Bare `raise` preserves the original traceback for the client.
            raise

    def do_action(self, context, action):
        if action.type == "protocol_version":
            # Protocol version 2 of the UDF exchange protocol.
            yield b"\x02"
        else:
            raise NotImplementedError

    def serve(self):
        """Block and serve requests; SIGTERM triggers a graceful shutdown."""
        print(f"listening on {self._location}")
        signal.signal(signal.SIGTERM, lambda s, f: self.shutdown())
        super(UdfServer, self).serve()
class JsonScalar(pa.ExtensionScalar):
    """Scalar for JsonType: decodes the stored JSON string on access."""

    def as_py(self):
        raw = self.value
        if raw is None:
            return None
        return json.loads(raw.as_py())
class JsonType(pa.ExtensionType):
    """Extension type for JSON values, stored as UTF-8 strings."""

    def __init__(self):
        super().__init__(pa.string(), "arrowudf.json")

    def __arrow_ext_serialize__(self):
        # The type carries no parameters.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        return JsonType()

    def __arrow_ext_scalar_class__(self):
        return JsonScalar
class DecimalScalar(pa.ExtensionScalar):
    """Scalar for DecimalType: parses the stored string back into a Decimal."""

    def as_py(self):
        raw = self.value
        if raw is None:
            return None
        return Decimal(raw.as_py())
class DecimalType(pa.ExtensionType):
    """Extension type for arbitrary-precision decimals, stored as strings."""

    def __init__(self):
        super().__init__(pa.string(), "arrowudf.decimal")

    def __arrow_ext_serialize__(self):
        # The type carries no parameters.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        return DecimalType()

    def __arrow_ext_scalar_class__(self):
        return DecimalScalar
# Register the custom extension types with pyarrow so IPC (de)serialization
# round-trips them automatically (raises if the names are already registered).
pa.register_extension_type(JsonType())
pa.register_extension_type(DecimalType())
def _to_data_type(t: Union[str, pa.DataType]) -> pa.DataType:
    """Accept either a SQL type name or a pyarrow type; return a pyarrow type."""
    return _string_to_data_type(t) if isinstance(t, str) else t
def _string_to_data_type(type: str):
    """Parse a SQL-style type name into the corresponding pyarrow DataType.

    Examples: "INT", "varchar", "DOUBLE PRECISION", "DECIMAL(10,2)",
    "INT[]", "STRUCT<a:INT, b:VARCHAR>".

    Raises ValueError for unsupported type names.
    """
    t = type.upper()
    if t.endswith("[]"):
        # Array type: the element type is everything before the trailing "[]".
        return pa.list_(_string_to_data_type(type[:-2]))
    elif t.startswith("STRUCT"):
        return _parse_struct_type(type)
    # FIX: several branches below used `t in ("NAME")` — with a single string
    # that is a *substring* test (e.g. `"8" in "UINT8"` is True), not a tuple
    # membership test. They are now exact comparisons.
    elif t == "NULL":
        return pa.null()
    elif t in ("BOOLEAN", "BOOL"):
        return pa.bool_()
    elif t in ("TINYINT", "INT8"):
        return pa.int8()
    elif t in ("SMALLINT", "INT16"):
        return pa.int16()
    elif t in ("INT", "INTEGER", "INT32"):
        return pa.int32()
    elif t in ("BIGINT", "INT64"):
        return pa.int64()
    elif t == "UINT8":
        return pa.uint8()
    elif t == "UINT16":
        return pa.uint16()
    elif t == "UINT32":
        return pa.uint32()
    elif t == "UINT64":
        return pa.uint64()
    elif t in ("FLOAT32", "REAL"):
        return pa.float32()
    elif t in ("FLOAT64", "DOUBLE PRECISION"):
        return pa.float64()
    elif t.startswith("DECIMAL") or t.startswith("NUMERIC"):
        if t in ("DECIMAL", "NUMERIC"):
            # No precision/scale given: use the string-backed extension type.
            return DecimalType()
        # Strip the prefix ("DECIMAL(" / "NUMERIC(", both 8 chars) and ")".
        rest = t[8:-1]
        if "," in rest:
            precision, scale = rest.split(",")
            return pa.decimal128(int(precision), int(scale))
        else:
            return pa.decimal128(int(rest), 0)
    elif t in ("DATE32", "DATE"):
        return pa.date32()
    elif t in ("TIME64", "TIME", "TIME WITHOUT TIME ZONE"):
        return pa.time64("us")
    elif t in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE"):
        return pa.timestamp("us")
    elif t.startswith("INTERVAL"):
        return pa.month_day_nano_interval()
    elif t in ("STRING", "VARCHAR"):
        return pa.string()
    elif t == "LARGE_STRING":
        return pa.large_string()
    elif t in ("JSON", "JSONB"):
        return JsonType()
    elif t in ("BINARY", "BYTEA"):
        return pa.binary()
    elif t == "LARGE_BINARY":
        return pa.large_binary()
    raise ValueError(f"Unsupported type: {t}")


def _parse_struct_type(type: str) -> pa.DataType:
    """Parse "STRUCT<name:type, ...>" into a pyarrow struct type.

    Commas inside nested "<...>" are ignored while splitting fields, so
    nested structs are handled correctly.
    """
    # Strip the "STRUCT<" prefix (7 chars) and the closing ">".
    type_list = type[7:-1]
    fields = []
    start = 0
    depth = 0
    for i, c in enumerate(type_list):
        if c == "<":
            depth += 1
        elif c == ">":
            depth -= 1
        elif c == "," and depth == 0:
            # A top-level comma terminates one "name:type" field.
            name, t = type_list[start:i].split(":", maxsplit=1)
            fields.append(pa.field(name.strip(), _string_to_data_type(t.strip())))
            start = i + 1
    # The final field (or the only one) has no trailing comma.
    if ":" in type_list[start:].strip():
        name, t = type_list[start:].split(":", maxsplit=1)
        fields.append(pa.field(name.strip(), _string_to_data_type(t.strip())))
    return pa.struct(fields)