arrow_udf_macros/lib.rs
1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use proc_macro::TokenStream;
16use proc_macro2::TokenStream as TokenStream2;
17use syn::{Error, Result};
18
19mod gen;
20mod parse;
21mod struct_type;
22mod types;
23mod utils;
24
25/// Derive `StructType` for user defined struct.
26///
27/// Structs that implement `StructType` can be used as Arrow struct types.
28///
29/// # Examples
30///
31/// ```ignore
32/// #[derive(StructType)]
33/// struct KeyValue<'a> {
34/// key: &'a str,
35/// value: &'a str,
36/// }
37/// ```
38///
39/// ```ignore
40/// #[function("split_kv(string) -> struct KeyValue")]
41/// fn split_kv(kv: &str) -> Option<KeyValue<'_>> {
42/// let (key, value) = kv.split_once('=')?;
43/// Some(KeyValue { key, value })
44/// }
45/// ```
46#[proc_macro_derive(StructType)]
47pub fn struct_type(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
48 match struct_type::gen(tokens.into()) {
49 Ok(output) => output.into(),
50 Err(err) => err.to_compile_error().into(),
51 }
52}
53
54/// Defining a function on Arrow arrays.
55///
56/// # Table of Contents
57///
58/// - [SQL Function Signature](#sql-function-signature)
59/// - [Multiple Function Definitions](#multiple-function-definitions)
60/// - [Rust Function Signature](#rust-function-signature)
61/// - [Nullable Arguments](#nullable-arguments)
62/// - [Return Value](#return-value)
63/// - [Optimization](#optimization)
64/// - [Functions Returning Strings](#functions-returning-strings)
65/// - [Table Function](#table-function)
66/// - [Registration and Invocation](#registration-and-invocation)
67/// - [Appendix: Type Matrix](#appendix-type-matrix)
68///
69/// The following example demonstrates a simple usage:
70///
71/// ```ignore
72/// #[function("add(int, int) -> int")]
73/// fn add(x: i32, y: i32) -> i32 {
74/// x + y
75/// }
76/// ```
77///
78/// # SQL Function Signature
79///
80/// Each function must have a signature, specified in the `function("...")` part of the macro
81/// invocation. The signature follows this pattern:
82///
83/// ```text
84/// name ( [arg_types],* [...] ) [ -> [setof] return_type ]
85/// ```
86///
87/// Where `name` is the function name.
88///
/// `arg_types` is a comma-separated list of argument types. The allowed data types are listed
/// in the `name` column of the appendix's [type matrix]. Wildcards or `auto` can also be used, as
91/// explained below. If the function is variadic, the last argument can be denoted as `...`.
92///
93/// When `setof` appears before the return type, this indicates that the function is a set-returning
94/// function (table function), meaning it can return multiple values instead of just one. For more
95/// details, see the section on table functions.
96///
97/// If no return type is specified, the function returns `null`.
98///
99/// ## Multiple Function Definitions
100///
101/// Multiple `#[function]` macros can be applied to a single generic Rust function to define
102/// multiple SQL functions of different types. For example:
103///
104/// ```ignore
105/// #[function("add(int16, int16) -> int16")]
106/// #[function("add(int32, int32) -> int32")]
107/// #[function("add(int64, int64) -> int64")]
108/// fn add<T: Add>(x: T, y: T) -> T {
109/// x + y
110/// }
111/// ```
112///
113/// # Rust Function Signature
114///
115/// The `#[function]` macro can handle various types of Rust functions.
116/// Each argument corresponds to the *Rust type* `T` in the [type matrix].
117/// The return value type can be any type that implements `AsRef<T>`.
118///
119/// ## Nullable Arguments
120///
121/// The functions above will only be called when all arguments are not null. If null arguments need
122/// to be considered, the `Option` type can be used:
123///
124/// ```ignore
125/// #[function("add(int, int) -> int")]
126/// fn add(x: Option<i32>, y: i32) -> i32 {...}
127/// ```
128///
129/// ## Return Value
130///
131/// Similarly, the return value type can be one of the following:
132///
133/// - `T`: Indicates that a non-null value is always returned, and errors will not occur.
134/// - `Option<T>`: Indicates that a null value may be returned, but errors will not occur.
135/// - `Result<T>`: Indicates that an error may occur, but a null value will not be returned.
136/// - `Result<Option<T>>`: Indicates that a null value may be returned, and an error may also occur.
137///
138/// ## Optimization
139///
/// When all input and output types of the function are *primitive types* (int2, int4, int8, float4, float8)
141/// and do not contain any Option or Result, the `#[function]` macro will automatically
142/// generate SIMD vectorized execution code.
143///
144/// Therefore, try to avoid returning `Option` and `Result` whenever possible.
145///
146/// ## Functions Returning Strings
147///
148/// For functions that return string types, you can also use the writer style function signature to
149/// avoid memory copying and dynamic memory allocation:
150///
151/// ```ignore
152/// #[function("trim(string) -> string")]
153/// fn trim(s: &str, writer: &mut impl Write) {
154/// writer.write_str(s.trim()).unwrap();
155/// }
156/// ```
157///
158/// If errors may be returned, then the return value should be `Result<()>`:
159///
160/// ```ignore
161/// #[function("trim(string) -> string")]
162/// fn trim(s: &str, writer: &mut impl Write) -> Result<()> {
163/// writer.write_str(s.trim()).unwrap();
164/// Ok(())
165/// }
166/// ```
167///
168/// If null values may be returned, then the return value should be `Option<()>`:
169///
170/// ```ignore
171/// #[function("trim(string) -> string")]
172/// fn trim(s: &str, writer: &mut impl Write) -> Option<()> {
173/// if s.is_empty() {
174/// None
175/// } else {
176/// writer.write_str(s.trim()).unwrap();
177/// Some(())
178/// }
179/// }
180/// ```
181///
182/// # Table Function
183///
184/// A table function is a special kind of function that can return multiple values instead of just
185/// one. Its function signature must include the `setof` keyword, and the Rust function should
186/// return an iterator of the form `impl Iterator<Item = T>` or its derived types.
187///
188/// For example:
189/// ```ignore
190/// #[function("generate_series(int32, int32) -> setof int32")]
191/// fn generate_series(start: i32, stop: i32) -> impl Iterator<Item = i32> {
192/// start..=stop
193/// }
194/// ```
195///
196/// Likewise, the return value `Iterator` can include `Option` or `Result` either internally or
197/// externally. For instance:
198///
199/// - `impl Iterator<Item = Result<T>>`
200/// - `Result<impl Iterator<Item = T>>`
201/// - `Result<impl Iterator<Item = Result<Option<T>>>>`
202///
203/// # Registration and Invocation
204///
205/// Every function defined by `#[function]` is automatically registered in the global function registry.
206///
/// You can look up the function by name and types:
208///
209/// ```ignore
210/// use arrow_udf::sig::REGISTRY;
211/// use arrow_schema::DataType::Int32;
212///
213/// let sig = REGISTRY.get("add", &[Int32, Int32], &Int32).unwrap();
214/// ```
215///
216/// # Appendix: Type Matrix
217///
218/// ## Base Types
219///
220/// | Arrow data type | Aliases | Rust type as argument | Rust type as return value |
221/// | -------------------- | ------------------ | ------------------------------ | ------------------------------ |
222/// | `boolean` | `bool` | `bool` | `bool` |
223/// | `int8` | | `i8` | `i8` |
224/// | `int16` | `smallint` | `i16` | `i16` |
225/// | `int32` | `int` | `i32` | `i32` |
226/// | `int64` | `bigint` | `i64` | `i64` |
227/// | `float32` | `real` | `f32` | `f32` |
/// | `float64` | `double precision` | `f64` | `f64` |
229/// | `date32` | `date` | [`chrono::NaiveDate`] | [`chrono::NaiveDate`] |
230/// | `time64` | `time` | [`chrono::NaiveTime`] | [`chrono::NaiveTime`] |
231/// | `timestamp` | | [`chrono::NaiveDateTime`] | [`chrono::NaiveDateTime`] |
232/// | `timestamptz` | | not supported yet | not supported yet |
233/// | `interval` | | [`arrow_udf::types::Interval`] | [`arrow_udf::types::Interval`] |
234/// | `string` | `varchar` | `&str` | `impl AsRef<str>`, e.g. `String`, `Box<str>`, `&str` |
235/// | `binary` | `bytea` | `&[u8]` | `impl AsRef<[u8]>`, e.g. `Vec<u8>`, `Box<[u8]>`, `&[u8]` |
236///
237/// ## Extension Types
238///
239/// We also support the following extension types that are not part of the Arrow data types:
240///
241/// | Data type | Metadata | Rust type as argument | Rust type as return value |
242/// | ----------- | ------------------- | ------------------------------ | ------------------------------ |
243/// | `decimal` | `arrowudf.decimal` | [`rust_decimal::Decimal`] | [`rust_decimal::Decimal`] |
244/// | `json` | `arrowudf.json` | [`serde_json::Value`] | [`serde_json::Value`] |
245///
246/// ## Array Types
247///
248/// | SQL type | Rust type as argument | Rust type as return value |
249/// | -------------------- | ------------------------- | ------------------------------ |
250/// | `int8[]` | `&[i8]` | `impl Iterator<Item = i8>` |
251/// | `int16[]` | `&[i16]` | `impl Iterator<Item = i16>` |
252/// | `int32[]` | `&[i32]` | `impl Iterator<Item = i32>` |
253/// | `int64[]` | `&[i64]` | `impl Iterator<Item = i64>` |
254/// | `float32[]` | `&[f32]` | `impl Iterator<Item = f32>` |
255/// | `float64[]` | `&[f64]` | `impl Iterator<Item = f64>` |
256/// | `string[]` | [`&StringArray`] | `impl Iterator<Item = &str>` |
257/// | `binary[]` | [`&BinaryArray`] | `impl Iterator<Item = &[u8]>` |
258/// | `largestring[]` | [`&LargeStringArray`] | `impl Iterator<Item = &str>` |
259/// | `largebinary[]` | [`&LargeBinaryArray`] | `impl Iterator<Item = &[u8]>` |
260/// | `others[]` | not supported yet | not supported yet |
261///
262/// ## Composite Types
263///
264/// | SQL type | Rust type as argument | Rust type as return value |
265/// | --------------------- | ------------------------- | ------------------------------ |
266/// | `struct<..>` | `UserDefinedStruct` | `UserDefinedStruct` |
267///
268/// [type matrix]: #appendix-type-matrix
269/// [`rust_decimal::Decimal`]: https://docs.rs/rust_decimal/1.33.1/rust_decimal/struct.Decimal.html
270/// [`chrono::NaiveDate`]: https://docs.rs/chrono/0.4.31/chrono/naive/struct.NaiveDate.html
271/// [`chrono::NaiveTime`]: https://docs.rs/chrono/0.4.31/chrono/naive/struct.NaiveTime.html
272/// [`chrono::NaiveDateTime`]: https://docs.rs/chrono/0.4.31/chrono/naive/struct.NaiveDateTime.html
273/// [`arrow_udf::types::Interval`]: https://docs.rs/arrow_udf/0.1.0/arrow_udf/types/struct.Interval.html
274/// [`serde_json::Value`]: https://docs.rs/serde_json/1.0.108/serde_json/enum.Value.html
275/// [`&StringArray`]: https://docs.rs/arrow/50.0.0/arrow/array/type.StringArray.html
276/// [`&BinaryArray`]: https://docs.rs/arrow/50.0.0/arrow/array/type.BinaryArray.html
277/// [`&LargeStringArray`]: https://docs.rs/arrow/50.0.0/arrow/array/type.LargeStringArray.html
278/// [`&LargeBinaryArray`]: https://docs.rs/arrow/50.0.0/arrow/array/type.LargeBinaryArray.html
279#[proc_macro_attribute]
280pub fn function(attr: TokenStream, item: TokenStream) -> TokenStream {
281 fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
282 let fn_attr: FunctionAttr = syn::parse(attr)?;
283 let user_fn: UserFunctionAttr = syn::parse(item.clone())?;
284
285 let mut tokens: TokenStream2 = item.into();
286 for attr in fn_attr.expand() {
287 tokens.extend(attr.generate_function_descriptor(&user_fn)?);
288 }
289 Ok(tokens)
290 }
291 match inner(attr, item) {
292 Ok(tokens) => tokens.into(),
293 Err(e) => e.to_compile_error().into(),
294 }
295}
296
/// Attributes parsed from the arguments of a `#[function(...)]` invocation.
#[derive(Debug, Clone, Default)]
struct FunctionAttr {
    /// Name of the function.
    name: String,
    /// Types of the input arguments.
    args: Vec<String>,
    /// Type of the return value.
    ret: String,
    /// `true` when this is a table function.
    is_table_function: bool,
    /// `true` when this is an append-only aggregate function.
    append_only: bool,
    /// Function to use for batch evaluation, if one was given.
    batch_fn: Option<String>,
    /// State type of an aggregate function.
    /// When absent, the return type is used as the state type.
    state: Option<String>,
    /// Initial state value of an aggregate function.
    /// When absent, the initial state is NULL.
    init_state: Option<String>,
    /// Function used for type inference, if one was given.
    type_infer: Option<String>,
    /// Generic type, if one was given.
    generic: Option<String>,
    /// `true` when the function is volatile.
    volatile: bool,
    /// Name of the generated batch function.
    /// When absent, the macro does not generate a batch function.
    output: Option<String>,
    /// When present, the macro generates a DuckDB struct with this name.
    duckdb: Option<String>,
    /// Custom visibility of the function, if one was given.
    visibility: Option<String>,
}
331
332/// Attributes from function signature `fn(..)`
333#[derive(Debug, Clone)]
334#[allow(dead_code)]
335struct UserFunctionAttr {
336 /// Function name
337 name: String,
338 /// Whether the function is async.
339 async_: bool,
340 /// Whether contains argument `&Context`.
341 context: bool,
342 /// Whether contains argument `&mut impl Write`.
343 write: bool,
344 /// Whether the last argument type is `retract: bool`.
345 retract: bool,
346 /// Whether each argument type is `Option<T>`.
347 args_option: Vec<bool>,
348 /// If the first argument type is `&mut T`, then `Some(T)`.
349 first_mut_ref_arg: Option<String>,
350 /// The return type kind.
351 return_type_kind: ReturnTypeKind,
352 /// The kind of inner type `T` in `impl Iterator<Item = T>`
353 iterator_item_kind: Option<ReturnTypeKind>,
354 /// The core return type without `Option` or `Result`.
355 core_return_type: String,
356 /// The number of generic types.
357 generic: usize,
358 /// The span of return type.
359 return_type_span: proc_macro2::Span,
360}
361
/// Kind of a return type, classified by its `Option` / `Result` wrapping.
///
/// The derived ordering follows the declaration order of the variants.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum ReturnTypeKind {
    /// No `Option` or `Result` wrapper.
    T,
    /// Wrapped in `Option`.
    Option,
    /// Wrapped in `Result`.
    Result,
    /// Wrapped in `Result` around an `Option`.
    ResultOption,
}

impl ReturnTypeKind {
    /// Reports whether this kind is one of the `Result<..>` kinds.
    const fn is_result(&self) -> bool {
        match self {
            Self::Result | Self::ResultOption => true,
            Self::T | Self::Option => false,
        }
    }
}
376
377impl FunctionAttr {
378 /// Return a unique name that can be used as an identifier.
379 fn ident_name(&self) -> String {
380 format!("{}_{}_{}", self.name, self.args.join("_"), self.ret)
381 .replace("[]", "array")
382 .replace("...", "variadic")
383 .replace(['<', ' ', ',', ':'], "_")
384 .replace('>', "")
385 .replace("__", "_")
386 }
387
388 /// Return a unique signature of the function.
389 fn normalize_signature(&self) -> String {
390 format!(
391 "{}({}){}{}",
392 self.name,
393 self.args.join(","),
394 if self.is_table_function { "->>" } else { "->" },
395 self.ret
396 )
397 }
398}
399
400impl UserFunctionAttr {
401 /// Returns true if the function is like `fn(T1, T2, .., Tn) -> T`.
402 fn is_pure(&self) -> bool {
403 !self.async_
404 && !self.write
405 && !self.context
406 && self.args_option.iter().all(|b| !b)
407 && self.return_type_kind == ReturnTypeKind::T
408 }
409
410 /// Returns true if the function may return error.
411 fn has_error(&self) -> bool {
412 self.return_type_kind.is_result()
413 || matches!(&self.iterator_item_kind, Some(k) if k.is_result())
414 }
415}