arrow-udf-flight 0.4.2

Client for remote Arrow UDFs.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
# Copyright 2024 RisingWave Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import *
import pyarrow as pa
import pyarrow.flight
import pyarrow.parquet
import inspect
import traceback
import json
from concurrent.futures import ThreadPoolExecutor
import concurrent
from decimal import Decimal
import signal


class UserDefinedFunction:
    """
    Common base class for every user-defined function exposed by the server.

    Concrete subclasses set the name and schema attributes below and override
    `eval_batch` to produce output batches.
    """

    # Name under which the function is registered and looked up.
    _name: str
    # Arrow schema of the function's arguments.
    _input_schema: pa.Schema
    # Arrow schema of the function's output columns.
    _result_schema: pa.Schema
    # Optional thread pool; stays `None` unless a subclass creates one.
    _executor: Optional[ThreadPoolExecutor] = None

    def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
        """
        Evaluate the function over one input batch and return an iterator of output batches.

        This base implementation produces no output.
        """
        return iter(())


class ScalarFunction(UserDefinedFunction):
    """
    Base interface for user-defined scalar function. A user-defined scalar function maps zero, one,
    or multiple scalar values to a new scalar value.

    Subclasses implement `eval`. By default it is called once per input row with that row's
    values. When constructed with ``batch=True``, it is called once per input batch with one
    Python list per input column, and its result is converted as the whole output column.

    Constructor keyword arguments (removed from ``kwargs`` before delegating to the base class):
    - batch: evaluate whole batches at once (default False).
    - io_threads: when greater than 1 and ``batch`` is False, rows are evaluated concurrently
      on a `ThreadPoolExecutor` with this many workers.
    """

    # Whether `eval` receives whole columns instead of single rows.
    _batch: bool

    def __init__(self, *args, **kwargs):
        self._batch = kwargs.pop("batch", False)
        io_threads = kwargs.pop("io_threads", None) or 1
        if not self._batch and io_threads > 1:
            # NOTE: this class never shuts the pool down; it lives as long as the object.
            self._executor = ThreadPoolExecutor(max_workers=io_threads)
        super().__init__(*args, **kwargs)

    def eval(self, *args) -> Any:
        """
        Method which defines the logic of the scalar function.
        """
        pass

    def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
        """
        Evaluate the function over ``batch`` and yield one single-column result batch.

        Every evaluation mode (batch, thread pool, sequential) dispatches through `eval`, so a
        subclass that only overrides `eval` works in all of them.
        """
        # Convert each input column to a list of Python values; extension scalars
        # (JSON, DECIMAL) decode themselves in their own `as_py`.
        inputs = [[v.as_py() for v in array] for array in batch]
        num_rows = batch.num_rows

        if self._batch:
            # evaluate the function once on the entire batch
            results = self.eval(*inputs)
        elif self._executor:
            # evaluate row by row, concurrently; `Executor.map` returns results in input order
            results = list(
                self._executor.map(
                    lambda row: self.eval(*row),  # manual `starmap`
                    # Rows are built by index rather than with zip(*inputs) so that a
                    # function with no arguments is still called once per input row.
                    ([col[i] for col in inputs] for i in range(num_rows)),
                )
            )
        else:
            # evaluate row by row, sequentially
            results = [self.eval(*[col[i] for col in inputs]) for i in range(num_rows)]

        array = _to_arrow_array(results, self._result_schema.types[0])

        yield pa.RecordBatch.from_arrays([array], schema=self._result_schema)


def _to_arrow_array(column: List, type: pa.DataType) -> pa.Array:
    """
    Convert a list of Python objects to an Arrow array of the given type.

    List and struct types are converted recursively, so JSON and DECIMAL extension values
    nested inside them get the same string encoding as top-level ones. ``None`` entries
    become nulls.

    Parameters:
    - column: the Python values to convert, one per row.
    - type: the Arrow data type of the resulting array.

    Returns the converted `pa.Array`.
    """
    if pa.types.is_list(type):
        # Flatten the list of lists into one child column plus int32 offsets.
        # A None row contributes no child values and is marked null in the mask.
        offsets = [0]
        values = []
        mask = []
        for array in column:
            if array is not None:
                values.extend(array)
            offsets.append(len(values))
            mask.append(array is None)
        offsets = pa.array(offsets, type=pa.int32())
        values = _to_arrow_array(values, type.value_type)
        mask = pa.array(mask, type=pa.bool_())
        return pa.ListArray.from_arrays(offsets, values, mask=mask)

    if pa.types.is_struct(type):
        # Build one child array per field by looking the field name up in each row
        # (rows are expected to support `.get`, e.g. dicts). A None row yields None
        # for every field and is masked as null at the struct level.
        arrays = [
            _to_arrow_array(
                [v.get(field.name) if v is not None else None for v in column],
                field.type,
            )
            for field in type
        ]
        mask = pa.array([v is None for v in column], type=pa.bool_())
        return pa.StructArray.from_arrays(arrays, fields=type, mask=mask)

    if type.equals(JsonType()):
        # JSON values are stored as their serialized text.
        s = pa.array(
            [json.dumps(v) if v is not None else None for v in column], type=pa.string()
        )
        return pa.ExtensionArray.from_storage(JsonType(), s)

    if type.equals(DecimalType()):
        # DECIMAL values are stored as plain positional strings (see `_decimal_to_str`).
        s = pa.array(
            [_decimal_to_str(v) if v is not None else None for v in column],
            type=pa.string(),
        )
        return pa.ExtensionArray.from_storage(DecimalType(), s)

    return pa.array(column, type=type)


def _decimal_to_str(v: Decimal) -> str:
    if not isinstance(v, Decimal):
        raise ValueError(f"Expected Decimal, got {v}")
    # use `f` format to avoid scientific notation, e.g. `1e10`
    return format(v, "f")


class TableFunction(UserDefinedFunction):
    """
    Base interface for user-defined table function. A user-defined table function maps zero, one,
    or multiple scalar values to a new table value.

    Subclasses implement `eval`, which is called once per input row and yields any number of
    results. `eval_batch` emits them in batches of at most `BATCH_SIZE` rows with two columns:
    the index of the input row that produced each result, and the result itself.
    """

    # Maximum number of rows in each output batch.
    BATCH_SIZE = 1024

    def eval(self, *args) -> Iterator:
        """
        Method which defines the logic of the table function.
        """
        yield

    def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
        """
        Run `eval` on every row of ``batch`` and yield the results as record batches.
        """

        class RecordBatchBuilder:
            """A utility class for constructing Arrow RecordBatch by row."""

            schema: pa.Schema
            columns: List[List]

            def __init__(self, schema: pa.Schema):
                self.schema = schema
                self.columns = [[] for _ in self.schema.types]

            def len(self) -> int:
                """Returns the number of rows in the RecordBatch being built."""
                return len(self.columns[0])

            def append(self, index: int, value: Any):
                """Appends a new (input row index, result value) row to the batch being built."""
                self.columns[0].append(index)
                self.columns[1].append(value)

            def build(self) -> pa.RecordBatch:
                """Builds the RecordBatch from the accumulated data and clears the state."""
                # Convert the columns to arrow arrays
                arrays = [
                    self._to_array(col, col_type)
                    for col, col_type in zip(self.columns, self.schema.types)
                ]
                # Reset columns
                self.columns = [[] for _ in self.schema.types]
                return pa.RecordBatch.from_arrays(arrays, schema=self.schema)

            @staticmethod
            def _to_array(col: List, col_type: pa.DataType) -> pa.Array:
                """Convert one accumulated column to an Arrow array of ``col_type``."""
                if pa.types.is_struct(col_type):
                    # Struct columns keep the plain `pa.array` conversion: `_to_arrow_array`
                    # looks struct fields up by name (`v.get(...)`), while the struct built for
                    # multi-column results in UserDefinedTableFunctionWrapper has fields all
                    # named "" and its values may be yielded positionally (e.g. as tuples).
                    return pa.array(col, col_type)
                # Use the same conversion as scalar functions so JSON / DECIMAL results get
                # the `json.dumps` / `_decimal_to_str` encoding into their string storage,
                # which a bare `pa.array` call would skip.
                return _to_arrow_array(col, col_type)

        builder = RecordBatchBuilder(self._result_schema)

        # Iterate through rows in the input RecordBatch
        for row_index in range(batch.num_rows):
            row = tuple(column[row_index].as_py() for column in batch)
            for result in self.eval(*row):
                builder.append(row_index, result)
                if builder.len() == self.BATCH_SIZE:
                    yield builder.build()
        # flush the final, partially filled batch
        if builder.len() != 0:
            yield builder.build()


class UserDefinedScalarFunctionWrapper(ScalarFunction):
    """
    Adapts a plain Python callable into a `ScalarFunction`.

    The input schema pairs the callable's positional parameter names with the declared
    input types; the result schema is one column named after the function. Extra keyword
    arguments (e.g. ``batch``, ``io_threads``) are forwarded to `ScalarFunction`.
    """

    # The wrapped user callable.
    _func: Callable

    def __init__(self, func, input_types, result_type, name=None, **kwargs):
        self._func = func
        # Fall back to the callable's own name, or its class name for callable objects.
        self._name = name if name else getattr(func, "__name__", func.__class__.__name__)

        param_names = inspect.getfullargspec(func).args
        param_types = [_to_data_type(t) for t in _to_list(input_types)]
        self._input_schema = pa.schema(zip(param_names, param_types))

        output_type = _to_data_type(result_type)
        self._result_schema = pa.schema([(self._name, output_type)])

        super().__init__(**kwargs)

    def __call__(self, *args):
        return self._func(*args)

    def eval(self, *args):
        return self._func(*args)


class UserDefinedTableFunctionWrapper(TableFunction):
    """
    Adapts a plain Python generator function into a `TableFunction`.

    The result schema has an int32 ``row`` column (index of the producing input row)
    followed by a column named after the function. When ``result_types`` is a list, that
    column is a struct with one unnamed field per listed type.
    """

    # The wrapped user callable.
    _func: Callable

    def __init__(self, func, input_types, result_types, name=None):
        self._func = func
        # Fall back to the callable's own name, or its class name for callable objects.
        self._name = name if name else getattr(func, "__name__", func.__class__.__name__)

        param_names = inspect.getfullargspec(func).args
        param_types = [_to_data_type(t) for t in _to_list(input_types)]
        self._input_schema = pa.schema(zip(param_names, param_types))

        if isinstance(result_types, list):
            # several output values are packed into a struct whose fields are all named ""
            output_type = pa.struct([("", _to_data_type(t)) for t in result_types])
        else:
            output_type = _to_data_type(result_types)
        self._result_schema = pa.schema([("row", pa.int32()), (self._name, output_type)])

    def __call__(self, *args):
        return self._func(*args)

    def eval(self, *args):
        return self._func(*args)


def _to_list(x):
    if isinstance(x, list):
        return x
    else:
        return [x]


def udf(
    input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]],
    result_type: Union[str, pa.DataType],
    name: Optional[str] = None,
    io_threads: Optional[int] = None,
    batch: bool = False,
) -> Callable:
    """
    Annotation for creating a user-defined scalar function.

    Parameters:
    - input_types: A list of strings or Arrow data types that specifies the input data types.
    - result_type: A string or an Arrow data type that specifies the return value type.
    - name: An optional string specifying the function name. If not provided, the original name will be used.
    - io_threads: Number of I/O threads used per data chunk for I/O bound functions.
    - batch: Whether the function accepts and returns a batch of data. When this is True, `io_threads` will take no effect.

    Returns a decorator that wraps the decorated function in a `UserDefinedScalarFunctionWrapper`.

    Example:
    ```
    @udf(input_types=['INT', 'INT'], result_type='INT')
    def gcd(x, y):
        while y != 0:
            (x, y) = (y, x % y)
        return x
    ```

    I/O bound Example:
    ```
    @udf(input_types=['INT'], result_type='INT', io_threads=64)
    def external_api(x):
        response = requests.get(my_endpoint + '?param=' + str(x))
        return response.json()["data"]
    ```

    Batched Example:
    ```
    @udf(input_types=['VARCHAR'], result_type='REAL[]', batch=True)
    def external_api(texts: List[str]) -> List[List[float]]:
        response = requests.post(my_endpoint, json={"inputs": texts})
        return response.json()["data"]
    ```
    """

    return lambda f: UserDefinedScalarFunctionWrapper(
        f, input_types, result_type, name, io_threads=io_threads, batch=batch
    )


def udtf(
    input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]],
    result_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]],
    name: Optional[str] = None,
) -> Callable:
    """
    Decorator factory for user-defined table functions.

    Parameters:
    - input_types: A string or Arrow data type, or a list of them, giving the argument types.
    - result_types: A string or Arrow data type, or a list of them, giving the output value type(s).
    - name: Optional function name; defaults to the decorated function's own name.

    Returns a decorator that wraps the decorated generator function in a
    `UserDefinedTableFunctionWrapper`.

    Example:
    ```
    @udtf(input_types='INT', result_types='INT')
    def series(n):
        for i in range(n):
            yield i
    ```
    """

    def decorator(f):
        return UserDefinedTableFunctionWrapper(f, input_types, result_types, name)

    return decorator


class UdfServer(pa.flight.FlightServerBase):
    """
    A server that provides user-defined functions to clients.

    Example:
    ```
    server = UdfServer(location="0.0.0.0:8815")
    server.add_function(my_udf)
    server.serve()
    ```
    """

    # UDF server based on Apache Arrow Flight protocol.
    # Reference: https://arrow.apache.org/cookbook/py/flight.html#simple-parquet-storage-service-with-arrow-flight

    # "host:port" the server listens on; the "grpc://" scheme is prepended in __init__.
    _location: str
    # Registered functions, keyed by function name.
    _functions: Dict[str, UserDefinedFunction]

    def __init__(self, location="0.0.0.0:8815", **kwargs):
        super(UdfServer, self).__init__("grpc://" + location, **kwargs)
        self._location = location
        self._functions = {}

    def get_flight_info(self, context, descriptor):
        """Return the flight info (input + result schema) of the function named by ``descriptor``."""
        udf = self._functions[descriptor.path[0].decode("utf-8")]
        return self._make_flight_info(udf)

    def _make_flight_info(self, udf: UserDefinedFunction) -> pa.flight.FlightInfo:
        """Return the flight info of a function."""
        # return the concatenation of input and output schema
        full_schema = pa.schema(list(udf._input_schema) + list(udf._result_schema))
        # we use `total_records` to indicate the number of input arguments
        return pa.flight.FlightInfo(
            schema=full_schema,
            descriptor=pa.flight.FlightDescriptor.for_path(udf._name),
            endpoints=[],
            total_records=len(udf._input_schema),
            total_bytes=0,
        )

    def list_flights(self, context, criteria):
        """Return the list of functions."""
        return [self._make_flight_info(udf) for udf in self._functions.values()]

    def add_function(self, udf: UserDefinedFunction):
        """
        Add a function to the server.

        Raises ValueError if a function with the same name is already registered.
        """
        name = udf._name
        if name in self._functions:
            raise ValueError("Function already exists: " + name)
        print(f"added function: {name}")
        self._functions[name] = udf

    def do_exchange(self, context, descriptor, reader, writer):
        """
        Call a function from the client.

        Each batch read from ``reader`` is passed to the function's `eval_batch`, and every
        resulting batch is written back through ``writer``.
        """
        udf = self._functions[descriptor.path[0].decode("utf-8")]
        writer.begin(udf._result_schema)
        try:
            for batch in reader:
                for output_batch in udf.eval_batch(batch.data):
                    writer.write_batch(output_batch)
        except Exception:
            # Log the traceback on the server (print_exc writes to stderr and returns
            # None, so it must not be wrapped in `print`), then re-raise unchanged so
            # the failure still propagates out of the handler.
            traceback.print_exc()
            raise

    def do_action(self, context, action):
        """Handle the "protocol_version" action by yielding the version as a single byte."""
        if action.type == "protocol_version":
            yield b"\x02"
        else:
            raise NotImplementedError

    def serve(self):
        """
        Block until the server shuts down.

        This method only returns if shutdown() is called or a signal (SIGINT, SIGTERM) received.
        """
        print(f"listening on {self._location}")
        # Translate SIGTERM into a graceful shutdown. NOTE: Python only allows installing
        # signal handlers from the main thread, so serve() must be called from there.
        signal.signal(signal.SIGTERM, lambda s, f: self.shutdown())
        super(UdfServer, self).serve()


class JsonScalar(pa.ExtensionScalar):
    """Scalar of `JsonType`; its Python value is the result of `json.loads` on the stored text."""

    def as_py(self, *args, **kwargs):
        # Extra arguments are accepted and ignored: newer pyarrow releases pass options
        # such as `maps_as_pydicts` to `as_py` when converting nested values (e.g. list
        # elements). NOTE(review): confirm against the pyarrow versions in use.
        return json.loads(self.value.as_py()) if self.value is not None else None


class JsonType(pa.ExtensionType):
    """
    Arrow extension type named "arrowudf.json": JSON values kept in `pa.string()` storage.
    """

    def __init__(self):
        super().__init__(pa.string(), "arrowudf.json")

    def __arrow_ext_serialize__(self):
        # The type carries no parameters, so there is no metadata to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # With no parameters, any serialized form maps back to a fresh JsonType.
        return JsonType()

    def __arrow_ext_scalar_class__(self):
        # Scalars of this type are represented by JsonScalar.
        return JsonScalar


class DecimalScalar(pa.ExtensionScalar):
    """Scalar of `DecimalType`; its Python value is a `Decimal` parsed from the stored text."""

    def as_py(self, *args, **kwargs):
        # Extra arguments are accepted and ignored: newer pyarrow releases pass options
        # such as `maps_as_pydicts` to `as_py` when converting nested values (e.g. list
        # elements). NOTE(review): confirm against the pyarrow versions in use.
        return Decimal(self.value.as_py()) if self.value is not None else None


class DecimalType(pa.ExtensionType):
    """
    Arrow extension type named "arrowudf.decimal": decimal values kept in `pa.string()` storage.
    """

    def __init__(self):
        super().__init__(pa.string(), "arrowudf.decimal")

    def __arrow_ext_serialize__(self):
        # The type carries no parameters, so there is no metadata to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # With no parameters, any serialized form maps back to a fresh DecimalType.
        return DecimalType()

    def __arrow_ext_scalar_class__(self):
        # Scalars of this type are represented by DecimalScalar.
        return DecimalScalar


# Register both extension types with pyarrow at import time, so that data tagged with the
# "arrowudf.json" / "arrowudf.decimal" extension names is deserialized as JsonType /
# DecimalType (per pyarrow's extension-type documentation, unregistered extension types
# are read back as their storage type).
pa.register_extension_type(JsonType())
pa.register_extension_type(DecimalType())


def _to_data_type(t: Union[str, pa.DataType]) -> pa.DataType:
    """
    Normalize a type specification to a `pyarrow.DataType`.

    Strings are parsed as SQL type names; any other value is returned unchanged.
    """
    if not isinstance(t, str):
        return t
    return _string_to_data_type(t)


def _string_to_data_type(type: str):
    """
    Convert a SQL data type string to `pyarrow.DataType`.

    Matching is case-insensitive and ignores surrounding whitespace. Besides the fixed
    scalar names handled below, the following forms are recognized:
    - ``T[]``: list of ``T``.
    - ``STRUCT<name:T, ...>``: struct; field names keep their original case.
    - ``DECIMAL`` / ``NUMERIC`` without arguments: the `DecimalType` extension type.
    - ``DECIMAL(p)`` / ``DECIMAL(p,s)`` (or NUMERIC): ``decimal128(p, s)``, ``s`` defaulting to 0.
    - anything starting with ``INTERVAL``: month/day/nano interval.

    Raises ValueError for unsupported type names.
    """
    # `spec` keeps the original casing (needed for struct field names); `t` is for matching.
    spec = type.strip()
    t = spec.upper()
    if t.endswith("[]"):
        return pa.list_(_string_to_data_type(spec[:-2]))
    elif t.startswith("STRUCT"):
        # extract 'STRUCT<a:INT, b:VARCHAR, c:STRUCT<d:INT>, ...>'
        type_list = spec[7:-1]  # strip "STRUCT<>"
        fields = []
        start = 0
        depth = 0
        for i, c in enumerate(type_list):
            if c == "<":
                depth += 1
            elif c == ">":
                depth -= 1
            elif c == "," and depth == 0:
                # a top-level comma ends the current "name:type" field
                field_name, field_type = type_list[start:i].split(":", maxsplit=1)
                fields.append(
                    pa.field(field_name.strip(), _string_to_data_type(field_type.strip()))
                )
                start = i + 1
        # the last field has no trailing comma; skip it if it has no ":" (e.g. "STRUCT<>")
        if ":" in type_list[start:].strip():
            field_name, field_type = type_list[start:].split(":", maxsplit=1)
            fields.append(
                pa.field(field_name.strip(), _string_to_data_type(field_type.strip()))
            )
        return pa.struct(fields)
    # NOTE: single names are compared with `==`; `t in ("NULL")` would be a substring
    # test because ("NULL") is a plain string, not a one-element tuple.
    elif t == "NULL":
        return pa.null()
    elif t in ("BOOLEAN", "BOOL"):
        return pa.bool_()
    elif t in ("TINYINT", "INT8"):
        return pa.int8()
    elif t in ("SMALLINT", "INT16"):
        return pa.int16()
    elif t in ("INT", "INTEGER", "INT32"):
        return pa.int32()
    elif t in ("BIGINT", "INT64"):
        return pa.int64()
    elif t == "UINT8":
        return pa.uint8()
    elif t == "UINT16":
        return pa.uint16()
    elif t == "UINT32":
        return pa.uint32()
    elif t == "UINT64":
        return pa.uint64()
    elif t in ("FLOAT32", "REAL"):
        return pa.float32()
    elif t in ("FLOAT64", "DOUBLE PRECISION"):
        return pa.float64()
    elif t.startswith(("DECIMAL", "NUMERIC")):
        if t in ("DECIMAL", "NUMERIC"):
            return DecimalType()
        rest = t[8:-1]  # remove "DECIMAL(" / "NUMERIC(" and ")"
        if "," in rest:
            precision, scale = rest.split(",")
            return pa.decimal128(int(precision), int(scale))
        else:
            return pa.decimal128(int(rest), 0)
    elif t in ("DATE32", "DATE"):
        return pa.date32()
    elif t in ("TIME64", "TIME", "TIME WITHOUT TIME ZONE"):
        return pa.time64("us")
    elif t in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE"):
        return pa.timestamp("us")
    elif t.startswith("INTERVAL"):
        return pa.month_day_nano_interval()
    elif t in ("STRING", "VARCHAR"):
        return pa.string()
    elif t == "LARGE_STRING":
        return pa.large_string()
    elif t in ("JSON", "JSONB"):
        return JsonType()
    elif t in ("BINARY", "BYTEA"):
        return pa.binary()
    elif t == "LARGE_BINARY":
        return pa.large_binary()

    raise ValueError(f"Unsupported type: {t}")