arrow_udf_macros/
lib.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use proc_macro::TokenStream;
16use proc_macro2::TokenStream as TokenStream2;
17use syn::{Error, Result};
18
19mod gen;
20mod parse;
21mod struct_type;
22mod types;
23mod utils;
24
25/// Derive `StructType` for user defined struct.
26///
27/// Structs that implement `StructType` can be used as Arrow struct types.
28///
29/// # Examples
30///
31/// ```ignore
32/// #[derive(StructType)]
33/// struct KeyValue<'a> {
34///     key: &'a str,
35///     value: &'a str,
36/// }
37/// ```
38///
39/// ```ignore
40/// #[function("split_kv(string) -> struct KeyValue")]
41/// fn split_kv(kv: &str) -> Option<KeyValue<'_>> {
42///     let (key, value) = kv.split_once('=')?;
43///     Some(KeyValue { key, value })
44/// }
45/// ```
46#[proc_macro_derive(StructType)]
47pub fn struct_type(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
48    match struct_type::gen(tokens.into()) {
49        Ok(output) => output.into(),
50        Err(err) => err.to_compile_error().into(),
51    }
52}
53
54/// Defining a function on Arrow arrays.
55///
56/// # Table of Contents
57///
58/// - [SQL Function Signature](#sql-function-signature)
59///     - [Multiple Function Definitions](#multiple-function-definitions)
60/// - [Rust Function Signature](#rust-function-signature)
61///     - [Nullable Arguments](#nullable-arguments)
62///     - [Return Value](#return-value)
63///     - [Optimization](#optimization)
64///     - [Functions Returning Strings](#functions-returning-strings)
65/// - [Table Function](#table-function)
66/// - [Registration and Invocation](#registration-and-invocation)
67/// - [Appendix: Type Matrix](#appendix-type-matrix)
68///
69/// The following example demonstrates a simple usage:
70///
71/// ```ignore
72/// #[function("add(int, int) -> int")]
73/// fn add(x: i32, y: i32) -> i32 {
74///     x + y
75/// }
76/// ```
77///
78/// # SQL Function Signature
79///
80/// Each function must have a signature, specified in the `function("...")` part of the macro
81/// invocation. The signature follows this pattern:
82///
83/// ```text
84/// name ( [arg_types],* [...] ) [ -> [setof] return_type ]
85/// ```
86///
87/// Where `name` is the function name.
88///
/// `arg_types` is a comma-separated list of argument types. The allowed data types are listed
/// in the `Arrow data type` and `Aliases` columns of the appendix's [type matrix]. Wildcards or
/// `auto` can also be used, as explained below. If the function is variadic, the last argument
/// can be denoted as `...`.
92///
93/// When `setof` appears before the return type, this indicates that the function is a set-returning
94/// function (table function), meaning it can return multiple values instead of just one. For more
95/// details, see the section on table functions.
96///
97/// If no return type is specified, the function returns `null`.
98///
99/// ## Multiple Function Definitions
100///
101/// Multiple `#[function]` macros can be applied to a single generic Rust function to define
102/// multiple SQL functions of different types. For example:
103///
104/// ```ignore
105/// #[function("add(int16, int16) -> int16")]
106/// #[function("add(int32, int32) -> int32")]
107/// #[function("add(int64, int64) -> int64")]
108/// fn add<T: Add>(x: T, y: T) -> T {
109///     x + y
110/// }
111/// ```
112///
113/// # Rust Function Signature
114///
115/// The `#[function]` macro can handle various types of Rust functions.
116/// Each argument corresponds to the *Rust type* `T` in the [type matrix].
117/// The return value type can be any type that implements `AsRef<T>`.
118///
119/// ## Nullable Arguments
120///
121/// The functions above will only be called when all arguments are not null. If null arguments need
122/// to be considered, the `Option` type can be used:
123///
124/// ```ignore
125/// #[function("add(int, int) -> int")]
126/// fn add(x: Option<i32>, y: i32) -> i32 {...}
127/// ```
128///
129/// ## Return Value
130///
131/// Similarly, the return value type can be one of the following:
132///
133/// - `T`: Indicates that a non-null value is always returned, and errors will not occur.
134/// - `Option<T>`: Indicates that a null value may be returned, but errors will not occur.
135/// - `Result<T>`: Indicates that an error may occur, but a null value will not be returned.
136/// - `Result<Option<T>>`: Indicates that a null value may be returned, and an error may also occur.
137///
138/// ## Optimization
139///
/// When all input and output types of the function are *primitive types* (int2, int4, int8, float4, float8)
141/// and do not contain any Option or Result, the `#[function]` macro will automatically
142/// generate SIMD vectorized execution code.
143///
144/// Therefore, try to avoid returning `Option` and `Result` whenever possible.
145///
146/// ## Functions Returning Strings
147///
148/// For functions that return string types, you can also use the writer style function signature to
149/// avoid memory copying and dynamic memory allocation:
150///
151/// ```ignore
152/// #[function("trim(string) -> string")]
153/// fn trim(s: &str, writer: &mut impl Write) {
154///     writer.write_str(s.trim()).unwrap();
155/// }
156/// ```
157///
158/// If errors may be returned, then the return value should be `Result<()>`:
159///
160/// ```ignore
161/// #[function("trim(string) -> string")]
162/// fn trim(s: &str, writer: &mut impl Write) -> Result<()> {
163///     writer.write_str(s.trim()).unwrap();
164///     Ok(())
165/// }
166/// ```
167///
168/// If null values may be returned, then the return value should be `Option<()>`:
169///
170/// ```ignore
171/// #[function("trim(string) -> string")]
172/// fn trim(s: &str, writer: &mut impl Write) -> Option<()> {
173///     if s.is_empty() {
174///         None
175///     } else {
176///         writer.write_str(s.trim()).unwrap();
177///         Some(())
178///     }
179/// }
180/// ```
181///
182/// # Table Function
183///
184/// A table function is a special kind of function that can return multiple values instead of just
185/// one. Its function signature must include the `setof` keyword, and the Rust function should
186/// return an iterator of the form `impl Iterator<Item = T>` or its derived types.
187///
188/// For example:
189/// ```ignore
190/// #[function("generate_series(int32, int32) -> setof int32")]
191/// fn generate_series(start: i32, stop: i32) -> impl Iterator<Item = i32> {
192///     start..=stop
193/// }
194/// ```
195///
196/// Likewise, the return value `Iterator` can include `Option` or `Result` either internally or
197/// externally. For instance:
198///
199/// - `impl Iterator<Item = Result<T>>`
200/// - `Result<impl Iterator<Item = T>>`
201/// - `Result<impl Iterator<Item = Result<Option<T>>>>`
202///
203/// # Registration and Invocation
204///
205/// Every function defined by `#[function]` is automatically registered in the global function registry.
206///
207/// You can lookup the function by name and types:
208///
209/// ```ignore
210/// use arrow_udf::sig::REGISTRY;
211/// use arrow_schema::DataType::Int32;
212///
213/// let sig = REGISTRY.get("add", &[Int32, Int32], &Int32).unwrap();
214/// ```
215///
216/// # Appendix: Type Matrix
217///
218/// ## Base Types
219///
220/// | Arrow data type      | Aliases            | Rust type as argument          | Rust type as return value      |
221/// | -------------------- | ------------------ | ------------------------------ | ------------------------------ |
222/// | `boolean`            | `bool`             | `bool`                         | `bool`                         |
223/// | `int8`               |                    | `i8`                           | `i8`                           |
224/// | `int16`              | `smallint`         | `i16`                          | `i16`                          |
225/// | `int32`              | `int`              | `i32`                          | `i32`                          |
226/// | `int64`              | `bigint`           | `i64`                          | `i64`                          |
227/// | `float32`            | `real`             | `f32`                          | `f32`                          |
/// | `float64`            | `double precision` | `f64`                          | `f64`                          |
229/// | `date32`             | `date`             | [`chrono::NaiveDate`]          | [`chrono::NaiveDate`]          |
230/// | `time64`             | `time`             | [`chrono::NaiveTime`]          | [`chrono::NaiveTime`]          |
231/// | `timestamp`          |                    | [`chrono::NaiveDateTime`]      | [`chrono::NaiveDateTime`]      |
232/// | `timestamptz`        |                    | not supported yet              | not supported yet              |
233/// | `interval`           |                    | [`arrow_udf::types::Interval`] | [`arrow_udf::types::Interval`] |
234/// | `string`             | `varchar`          | `&str`                         | `impl AsRef<str>`, e.g. `String`, `Box<str>`, `&str`     |
235/// | `binary`             | `bytea`            | `&[u8]`                        | `impl AsRef<[u8]>`, e.g. `Vec<u8>`, `Box<[u8]>`, `&[u8]` |
236///
237/// ## Extension Types
238///
239/// We also support the following extension types that are not part of the Arrow data types:
240///
241/// | Data type   | Metadata            | Rust type as argument          | Rust type as return value      |
242/// | ----------- | ------------------- | ------------------------------ | ------------------------------ |
243/// | `decimal`   | `arrowudf.decimal`  | [`rust_decimal::Decimal`]      | [`rust_decimal::Decimal`]      |
244/// | `json`      | `arrowudf.json`     | [`serde_json::Value`]          | [`serde_json::Value`]          |
245///
246/// ## Array Types
247///
248/// | SQL type              | Rust type as argument     | Rust type as return value      |
249/// | --------------------  | ------------------------- | ------------------------------ |
250/// | `int8[]`              | `&[i8]`                   | `impl Iterator<Item = i8>`     |
251/// | `int16[]`             | `&[i16]`                  | `impl Iterator<Item = i16>`    |
252/// | `int32[]`             | `&[i32]`                  | `impl Iterator<Item = i32>`    |
253/// | `int64[]`             | `&[i64]`                  | `impl Iterator<Item = i64>`    |
254/// | `float32[]`           | `&[f32]`                  | `impl Iterator<Item = f32>`    |
255/// | `float64[]`           | `&[f64]`                  | `impl Iterator<Item = f64>`    |
256/// | `string[]`            | [`&StringArray`]          | `impl Iterator<Item = &str>`   |
257/// | `binary[]`            | [`&BinaryArray`]          | `impl Iterator<Item = &[u8]>`  |
258/// | `largestring[]`       | [`&LargeStringArray`]     | `impl Iterator<Item = &str>`   |
259/// | `largebinary[]`       | [`&LargeBinaryArray`]     | `impl Iterator<Item = &[u8]>`  |
260/// | `others[]`            | not supported yet         | not supported yet              |
261///
262/// ## Composite Types
263///
264/// | SQL type              | Rust type as argument     | Rust type as return value      |
265/// | --------------------- | ------------------------- | ------------------------------ |
266/// | `struct<..>`          | `UserDefinedStruct`       | `UserDefinedStruct`            |
267///
268/// [type matrix]: #appendix-type-matrix
269/// [`rust_decimal::Decimal`]: https://docs.rs/rust_decimal/1.33.1/rust_decimal/struct.Decimal.html
270/// [`chrono::NaiveDate`]: https://docs.rs/chrono/0.4.31/chrono/naive/struct.NaiveDate.html
271/// [`chrono::NaiveTime`]: https://docs.rs/chrono/0.4.31/chrono/naive/struct.NaiveTime.html
272/// [`chrono::NaiveDateTime`]: https://docs.rs/chrono/0.4.31/chrono/naive/struct.NaiveDateTime.html
273/// [`arrow_udf::types::Interval`]: https://docs.rs/arrow_udf/0.1.0/arrow_udf/types/struct.Interval.html
274/// [`serde_json::Value`]: https://docs.rs/serde_json/1.0.108/serde_json/enum.Value.html
275/// [`&StringArray`]: https://docs.rs/arrow/50.0.0/arrow/array/type.StringArray.html
276/// [`&BinaryArray`]: https://docs.rs/arrow/50.0.0/arrow/array/type.BinaryArray.html
277/// [`&LargeStringArray`]: https://docs.rs/arrow/50.0.0/arrow/array/type.LargeStringArray.html
278/// [`&LargeBinaryArray`]: https://docs.rs/arrow/50.0.0/arrow/array/type.LargeBinaryArray.html
279#[proc_macro_attribute]
280pub fn function(attr: TokenStream, item: TokenStream) -> TokenStream {
281    fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
282        let fn_attr: FunctionAttr = syn::parse(attr)?;
283        let user_fn: UserFunctionAttr = syn::parse(item.clone())?;
284
285        let mut tokens: TokenStream2 = item.into();
286        for attr in fn_attr.expand() {
287            tokens.extend(attr.generate_function_descriptor(&user_fn)?);
288        }
289        Ok(tokens)
290    }
291    match inner(attr, item) {
292        Ok(tokens) => tokens.into(),
293        Err(e) => e.to_compile_error().into(),
294    }
295}
296
/// Attributes parsed from the argument of a `#[function(...)]` invocation.
#[derive(Clone, Debug, Default)]
struct FunctionAttr {
    /// Name of the function.
    name: String,
    /// Types of the input arguments, in order.
    args: Vec<String>,
    /// Type of the returned value.
    ret: String,
    /// `true` when this is a table function.
    is_table_function: bool,
    /// `true` when this is an append-only aggregate function.
    append_only: bool,
    /// Function to use for batch evaluation, if one was given.
    batch_fn: Option<String>,
    /// State type of an aggregate function.
    /// Falls back to the return type when absent.
    state: Option<String>,
    /// Initial state value of an aggregate function.
    /// Falls back to NULL when absent.
    init_state: Option<String>,
    /// Function used for type inference.
    type_infer: Option<String>,
    /// Generic type.
    generic: Option<String>,
    /// `true` when the function is volatile.
    volatile: bool,
    /// Name of the batch function to generate.
    /// When absent, the macro generates no batch function.
    output: Option<String>,
    /// When present, the macro generates a DuckDB struct with this name.
    duckdb: Option<String>,
    /// Custom visibility for the function.
    visibility: Option<String>,
}
331
332/// Attributes from function signature `fn(..)`
333#[derive(Debug, Clone)]
334#[allow(dead_code)]
335struct UserFunctionAttr {
336    /// Function name
337    name: String,
338    /// Whether the function is async.
339    async_: bool,
340    /// Whether contains argument `&Context`.
341    context: bool,
342    /// Whether contains argument `&mut impl Write`.
343    write: bool,
344    /// Whether the last argument type is `retract: bool`.
345    retract: bool,
346    /// Whether each argument type is `Option<T>`.
347    args_option: Vec<bool>,
348    /// If the first argument type is `&mut T`, then `Some(T)`.
349    first_mut_ref_arg: Option<String>,
350    /// The return type kind.
351    return_type_kind: ReturnTypeKind,
352    /// The kind of inner type `T` in `impl Iterator<Item = T>`
353    iterator_item_kind: Option<ReturnTypeKind>,
354    /// The core return type without `Option` or `Result`.
355    core_return_type: String,
356    /// The number of generic types.
357    generic: usize,
358    /// The span of return type.
359    return_type_span: proc_macro2::Span,
360}
361
/// How a return type (or an iterator item type) wraps its core type `T`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum ReturnTypeKind {
    /// A bare `T`.
    T,
    /// `Option<T>`.
    Option,
    /// `Result<T>`.
    Result,
    /// `Result<Option<T>>`.
    ResultOption,
}

impl ReturnTypeKind {
    /// Tells whether the outermost wrapper is `Result<..>`.
    const fn is_result(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide which side it is on.
        match self {
            Self::Result | Self::ResultOption => true,
            Self::T | Self::Option => false,
        }
    }
}
376
377impl FunctionAttr {
378    /// Return a unique name that can be used as an identifier.
379    fn ident_name(&self) -> String {
380        format!("{}_{}_{}", self.name, self.args.join("_"), self.ret)
381            .replace("[]", "array")
382            .replace("...", "variadic")
383            .replace(['<', ' ', ',', ':'], "_")
384            .replace('>', "")
385            .replace("__", "_")
386    }
387
388    /// Return a unique signature of the function.
389    fn normalize_signature(&self) -> String {
390        format!(
391            "{}({}){}{}",
392            self.name,
393            self.args.join(","),
394            if self.is_table_function { "->>" } else { "->" },
395            self.ret
396        )
397    }
398}
399
400impl UserFunctionAttr {
401    /// Returns true if the function is like `fn(T1, T2, .., Tn) -> T`.
402    fn is_pure(&self) -> bool {
403        !self.async_
404            && !self.write
405            && !self.context
406            && self.args_option.iter().all(|b| !b)
407            && self.return_type_kind == ReturnTypeKind::T
408    }
409
410    /// Returns true if the function may return error.
411    fn has_error(&self) -> bool {
412        self.return_type_kind.is_result()
413            || matches!(&self.iterator_item_kind, Some(k) if k.is_result())
414    }
415}