arroyo_udf_common/
parse.rs

1use anyhow::{anyhow, bail};
2use arrow::datatypes::{DataType, Field, TimeUnit};
3use regex::Regex;
4use std::sync::Arc;
5use std::time::Duration;
6use syn::PathArguments::AngleBracketed;
7use syn::__private::ToTokens;
8use syn::{FnArg, GenericArgument, ItemFn, LitInt, LitStr, ReturnType, Type};
9
10/// An Arrow DataType that also carries around its own nullability info
11#[derive(Clone, Debug, Eq, PartialEq)]
12pub struct NullableType {
13    pub data_type: DataType,
14    pub nullable: bool,
15}
16
17impl NullableType {
18    pub fn new(data_type: DataType, nullable: bool) -> Self {
19        Self {
20            data_type,
21            nullable,
22        }
23    }
24
25    pub fn null(data_type: DataType) -> Self {
26        Self {
27            data_type,
28            nullable: true,
29        }
30    }
31
32    pub fn not_null(data_type: DataType) -> Self {
33        Self {
34            data_type,
35            nullable: false,
36        }
37    }
38
39    pub fn with_nullability(&self, nullable: bool) -> Self {
40        Self {
41            data_type: self.data_type.clone(),
42            nullable,
43        }
44    }
45}
46
47pub fn is_vec_u8(typ: &Type) -> bool {
48    let Some(inner) = ParsedUdf::vec_inner_type(typ) else {
49        return false;
50    };
51
52    matches!(
53        rust_to_arrow(&inner, true),
54        Ok(NullableType {
55            data_type: DataType::UInt8,
56            nullable: false
57        })
58    )
59}
60
61pub(crate) fn rust_to_arrow(typ: &Type, expect_owned: bool) -> anyhow::Result<NullableType> {
62    match typ {
63        Type::Path(pat) => {
64            let last = pat.path.segments.last().unwrap();
65            if last.ident == "Option" {
66                let AngleBracketed(args) = &last.arguments else {
67                    bail!("invalid Rust type; Option must have arguments");
68                };
69
70                let Some(GenericArgument::Type(inner)) = args.args.first() else {
71                    bail!("invalid Rust type; Option must have an inner type parameter")
72                };
73
74                Ok(rust_to_arrow(inner, expect_owned)?.with_nullability(true))
75            } else {
76                let mut dt = rust_primitive_to_arrow(typ);
77
78                if dt.is_none() {
79                    dt = Some(
80                        match (
81                            render_path(typ)
82                                .ok_or_else(|| anyhow!("unsupported Rust type1"))?
83                                .as_str(),
84                            expect_owned,
85                        ) {
86                            ("String", true) => DataType::Utf8,
87                            ("String", false) => {
88                                bail!("expected reference type &str instead of String")
89                            }
90                            ("Vec<u8>", true) => DataType::Binary,
91                            ("Vec<u8>", false) => {
92                                bail!("expected reference type &[u8] instead of Vec<u8>")
93                            }
94                            (t, _) => bail!("unsupported Rust type {}", t),
95                        },
96                    );
97                }
98
99                Ok(NullableType::not_null(
100                    dt.ok_or_else(|| anyhow!("unsupported Rust type2"))?,
101                ))
102            }
103        }
104        Type::Reference(r) => {
105            let t = render_path(&r.elem).ok_or_else(|| anyhow!("unsupported Rust type3"))?;
106
107            let dt = match (t.as_str(), rust_primitive_to_arrow(&r.elem), expect_owned) {
108                ("String", _, false) => bail!("expected &str, not &String"),
109                ("String", _, true) => {
110                    bail!("expected owned String, not &String (hint: remove the &)")
111                }
112                ("Vec<u8>", _, false) => bail!("expected &[u8], not &Vec<u8>"),
113                ("Vec<u8>", _, true) => {
114                    bail!("expected owned Vec<u8>, not &Vec<u8> (hint: remove the &)")
115                }
116                ("str", _, false) => DataType::Utf8,
117                ("str", _, true) => bail!("expected owned String, not &str"),
118                ("[u8]", _, false) => DataType::Binary,
119                ("[u8]", _, true) => bail!("expected owned Vec<u8>, not &[u8]"),
120                (t, Some(_), _) => bail!(
121                    "unexpected &{}; primitives should be passed by value (hint: remove the &)",
122                    t
123                ),
124                _ => {
125                    bail!("unsupported Rust data type")
126                }
127            };
128
129            Ok(NullableType::not_null(dt))
130        }
131        _ => bail!("unsupported Rust data type"),
132    }
133}
134
135fn render_path(typ: &Type) -> Option<String> {
136    match typ {
137        Type::Path(pat) => {
138            let path: Vec<String> = pat
139                .path
140                .segments
141                .iter()
142                .map(|s| s.to_token_stream().to_string().replace(' ', ""))
143                .collect();
144
145            Some(path.join("::"))
146        }
147        Type::Slice(t) => Some(format!("[{}]", render_path(&t.elem)?)),
148        _ => None,
149    }
150}
151
152fn rust_primitive_to_arrow(typ: &Type) -> Option<DataType> {
153    match render_path(typ)?.as_str() {
154        "bool" => Some(DataType::Boolean),
155        "i8" => Some(DataType::Int8),
156        "i16" => Some(DataType::Int16),
157        "i32" => Some(DataType::Int32),
158        "i64" => Some(DataType::Int64),
159        "u8" => Some(DataType::UInt8),
160        "u16" => Some(DataType::UInt16),
161        "u32" => Some(DataType::UInt32),
162        "u64" => Some(DataType::UInt64),
163        "f16" => Some(DataType::Float16),
164        "f32" => Some(DataType::Float32),
165        "f64" => Some(DataType::Float64),
166        "SystemTime" | "std::time::SystemTime" => {
167            Some(DataType::Timestamp(TimeUnit::Microsecond, None))
168        }
169        "Duration" | "std::time::Duration" => Some(DataType::Duration(TimeUnit::Microsecond)),
170        _ => None,
171    }
172}
173
/// Arrow-level description of a UDF: its argument types, return type, and execution mode.
#[derive(Clone, Debug)]
pub struct UdfDef {
    /// Arrow types of the function's arguments.
    /// NOTE(review): presumably in declaration order, as `ParsedUdf::args` is — confirm.
    pub args: Vec<NullableType>,
    /// Arrow type of the function's return value.
    pub ret: NullableType,
    /// NOTE(review): presumably marks an aggregate UDF; it is not set anywhere in this
    /// file — confirm against the code that constructs `UdfDef`.
    pub aggregate: bool,
    /// Whether the function is sync or async (with its async options).
    pub udf_type: UdfType,
}
181
/// Options for an `async` UDF, overridable through its `#[udf(...)]` attribute.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct AsyncOptions {
    /// Set by the `ordered` / `unordered` attribute flags.
    /// NOTE(review): presumably controls whether results keep input order — confirm.
    pub ordered: bool,
    /// Set by `timeout = "<duration>"` (e.g. `"500ms"`, `"5s"`).
    pub timeout: Duration,
    /// Set by `allowed_in_flight = <integer>`.
    /// NOTE(review): presumably caps the number of concurrent calls — confirm.
    pub max_concurrency: usize,
}

impl Default for AsyncOptions {
    /// `ordered = false`, `timeout = 5s`, `max_concurrency = 1000`.
    fn default() -> Self {
        const DEFAULT_TIMEOUT_SECS: u64 = 5;
        const DEFAULT_MAX_CONCURRENCY: usize = 1000;

        AsyncOptions {
            ordered: false,
            timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
            max_concurrency: DEFAULT_MAX_CONCURRENCY,
        }
    }
}
198
199#[derive(Copy, Clone, Debug, Eq, PartialEq)]
200pub enum UdfType {
201    Sync,
202    Async(AsyncOptions),
203}
204
205impl UdfType {
206    pub fn is_async(&self) -> bool {
207        !matches!(self, UdfType::Sync)
208    }
209}
210
211fn parse_duration(input: &str) -> anyhow::Result<Duration> {
212    let r = Regex::new(r"^(\d+)\s*([a-zA-Zµ]+)$").unwrap();
213    let captures = r
214        .captures(input)
215        .ok_or_else(|| anyhow!("invalid duration specification '{}'", input))?;
216    let mut capture = captures.iter();
217
218    capture.next();
219
220    let n: u64 = capture.next().unwrap().unwrap().as_str().parse().unwrap();
221    let unit = capture.next().unwrap().unwrap().as_str();
222
223    Ok(match unit {
224        "ns" | "nanos" => Duration::from_nanos(n),
225        "µs" | "micros" => Duration::from_micros(n),
226        "ms" | "millis" => Duration::from_millis(n),
227        "s" | "secs" | "seconds" => Duration::from_secs(n),
228        "m" | "mins" | "minutes" => Duration::from_secs(n * 60),
229        "h" | "hrs" | "hours" => Duration::from_secs(n * 60 * 60),
230        x => bail!("unknown time unit '{}'", x),
231    })
232}
233
234pub struct ParsedUdf {
235    pub function: String,
236    pub name: String,
237    pub args: Vec<NullableType>,
238    pub vec_arguments: usize,
239    pub ret_type: NullableType,
240    pub udf_type: UdfType,
241}
242
243impl ParsedUdf {
244    pub fn vec_inner_type(ty: &syn::Type) -> Option<syn::Type> {
245        if let syn::Type::Path(syn::TypePath { path, .. }) = ty {
246            if let Some(segment) = path.segments.last() {
247                if segment.ident == "Vec" {
248                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
249                        if args.args.len() == 1 {
250                            if let syn::GenericArgument::Type(inner_ty) = &args.args[0] {
251                                return Some(inner_ty.clone());
252                            }
253                        }
254                    }
255                }
256            }
257        }
258        None
259    }
260
261    pub fn try_parse(function: &ItemFn) -> anyhow::Result<ParsedUdf> {
262        let name = function.sig.ident.to_string();
263        let mut args = vec![];
264        let mut vec_arguments = 0;
265        for (i, arg) in function.sig.inputs.iter().enumerate() {
266            match arg {
267                FnArg::Receiver(_) => {
268                    bail!(
269                        "Function {} has a 'self' argument, which is not allowed",
270                        name
271                    )
272                }
273                FnArg::Typed(t) => {
274                    let vec_type = Self::vec_inner_type(&t.ty);
275                    if vec_type.is_some() {
276                        vec_arguments += 1;
277                        let vec_type = rust_to_arrow(vec_type.as_ref().unwrap(), false).map_err(|e| {
278                            anyhow!(
279                                "Could not convert function {name} inner vector arg {i} into an Arrow data type: {e}",
280                            )
281                        })?;
282
283                        args.push(NullableType::not_null(DataType::List(Arc::new(
284                            Field::new("item", vec_type.data_type, vec_type.nullable),
285                        ))));
286                    } else {
287                        args.push(rust_to_arrow(&t.ty, false).map_err(|e| {
288                            anyhow!(
289                                "Could not convert function {name} arg {i} into a SQL data type: {e}",
290                            )
291                        })?);
292                    }
293                }
294            }
295        }
296
297        let ret = match &function.sig.output {
298            ReturnType::Default => bail!("Function {} return type must be specified", name),
299            ReturnType::Type(_, t) => rust_to_arrow(t, true).map_err(|e| {
300                anyhow!("Could not convert function {name} return type into a SQL data type: {e}",)
301            })?,
302        };
303
304        let udf_type = if function.sig.asyncness.is_some() {
305            let mut t = AsyncOptions::default();
306
307            if let Some(attr) = function
308                .attrs
309                .iter()
310                .find(|attr| attr.path().is_ident("udf"))
311            {
312                if attr.meta.require_path_only().is_err() {
313                    attr.parse_nested_meta(|meta| {
314                        if meta.path.is_ident("ordered") {
315                            t.ordered = true;
316                        } else if meta.path.is_ident("unordered") {
317                            t.ordered = false;
318                        } else if meta.path.is_ident("allowed_in_flight") {
319                            let value = meta.value()?;
320                            let s: LitInt = value.parse()?;
321                            let n: usize = s
322                                .base10_digits()
323                                .parse()
324                                .map_err(|_| meta.error("expected number"))?;
325                            t.max_concurrency = n;
326                        } else if meta.path.is_ident("timeout") {
327                            let value = meta.value()?;
328                            let s: LitStr = value.parse()?;
329                            t.timeout = parse_duration(&s.value()).map_err(|e| meta.error(e))?;
330                        } else {
331                            return Err(meta.error(format!(
332                                "unsupported attribute '{}'",
333                                meta.path.to_token_stream()
334                            )));
335                        }
336                        Ok(())
337                    })?;
338                }
339            }
340
341            UdfType::Async(t)
342        } else {
343            UdfType::Sync
344        };
345
346        Ok(ParsedUdf {
347            function: function.into_token_stream().to_string(),
348            name,
349            args,
350            vec_arguments,
351            ret_type: ret,
352            udf_type,
353        })
354    }
355}
356
357pub fn inner_type(dt: &DataType) -> Option<DataType> {
358    match dt {
359        DataType::List(f) => Some(f.data_type().clone()),
360        _ => None,
361    }
362}
363
#[cfg(test)]
mod tests {
    use crate::parse::{parse_duration, rust_to_arrow, NullableType};
    use arrow::datatypes::{DataType, TimeUnit};
    use std::time::Duration;
    use syn::parse_quote;

    #[test]
    fn test_duration() {
        assert_eq!(Duration::from_secs(5), parse_duration("5s").unwrap());
        assert_eq!(Duration::from_secs(5), parse_duration("5 seconds").unwrap());
        assert_eq!(Duration::from_secs(5), parse_duration("5   secs").unwrap());

        assert_eq!(Duration::from_millis(10), parse_duration("10ms").unwrap());
        assert_eq!(
            Duration::from_millis(110),
            parse_duration("110millis").unwrap()
        );

        // One representative of each remaining unit family.
        assert_eq!(Duration::from_nanos(7), parse_duration("7ns").unwrap());
        assert_eq!(Duration::from_micros(3), parse_duration("3micros").unwrap());
        assert_eq!(Duration::from_secs(2 * 60), parse_duration("2m").unwrap());
        assert_eq!(
            Duration::from_secs(2 * 60 * 60),
            parse_duration("2 hours").unwrap()
        );

        assert!(parse_duration("-10ms").is_err());
        assert!(parse_duration("10.0s").is_err());
        assert!(parse_duration("5s what").is_err());

        // Empty input, missing number, missing unit, unknown unit.
        assert!(parse_duration("").is_err());
        assert!(parse_duration("s").is_err());
        assert!(parse_duration("10").is_err());
        assert!(parse_duration("5 fortnights").is_err());
    }

    #[test]
    fn test_rust_to_arrow() {
        assert_eq!(
            rust_to_arrow(&parse_quote!(i32), false).unwrap(),
            NullableType::not_null(DataType::Int32)
        );
        assert_eq!(
            rust_to_arrow(&parse_quote!(Option<i32>), false).unwrap(),
            NullableType::null(DataType::Int32)
        );
        assert_eq!(
            rust_to_arrow(&parse_quote!(Vec<u8>), true).unwrap(),
            NullableType::not_null(DataType::Binary)
        );
        assert_eq!(
            rust_to_arrow(&parse_quote!(&[u8]), false).unwrap(),
            NullableType::not_null(DataType::Binary)
        );
        assert_eq!(
            rust_to_arrow(&parse_quote!(Option<Vec<u8>>), true).unwrap(),
            NullableType::null(DataType::Binary)
        );

        assert_eq!(
            rust_to_arrow(&parse_quote!(u64), false).unwrap(),
            NullableType::not_null(DataType::UInt64)
        );
        assert_eq!(
            rust_to_arrow(&parse_quote!(f32), false).unwrap(),
            NullableType::not_null(DataType::Float32)
        );
        assert_eq!(
            rust_to_arrow(&parse_quote!(bool), false).unwrap(),
            NullableType::not_null(DataType::Boolean)
        );

        assert_eq!(
            rust_to_arrow(&parse_quote!(Option<f64>), false).unwrap(),
            NullableType::null(DataType::Float64)
        );
        assert_eq!(
            rust_to_arrow(&parse_quote!(Option<bool>), false).unwrap(),
            NullableType::null(DataType::Boolean)
        );

        assert_eq!(
            rust_to_arrow(&parse_quote!(String), true).unwrap(),
            NullableType::not_null(DataType::Utf8)
        );
        assert_eq!(
            rust_to_arrow(&parse_quote!(&str), false).unwrap(),
            NullableType::not_null(DataType::Utf8)
        );

        assert_eq!(
            rust_to_arrow(&parse_quote!(Option<String>), true).unwrap(),
            NullableType::null(DataType::Utf8)
        );
        assert_eq!(
            rust_to_arrow(&parse_quote!(Option<&str>), false).unwrap(),
            NullableType::null(DataType::Utf8)
        );

        // Time types, bare and fully qualified.
        assert_eq!(
            rust_to_arrow(&parse_quote!(std::time::SystemTime), true).unwrap(),
            NullableType::not_null(DataType::Timestamp(TimeUnit::Microsecond, None))
        );
        assert_eq!(
            rust_to_arrow(&parse_quote!(Duration), false).unwrap(),
            NullableType::not_null(DataType::Duration(TimeUnit::Microsecond))
        );

        assert_eq!(
            rust_to_arrow(&parse_quote!(HashMap<String, i32>), false).ok(),
            None
        );
        assert_eq!(rust_to_arrow(&parse_quote!(CustomStruct), false).ok(), None);
        assert!(rust_to_arrow(&parse_quote!((i32, i32)), false).is_err());

        // Wrong ownership for the position is rejected.
        assert_eq!(rust_to_arrow(&parse_quote!(Vec<u8>), false).ok(), None);
        assert_eq!(rust_to_arrow(&parse_quote!(&[u8]), true).ok(), None);
        assert!(rust_to_arrow(&parse_quote!(String), false).is_err());
        assert!(rust_to_arrow(&parse_quote!(&str), true).is_err());
        assert!(rust_to_arrow(&parse_quote!(&String), false).is_err());

        // Primitives must be passed by value.
        assert!(rust_to_arrow(&parse_quote!(&i32), false).is_err());
    }
}