1use anyhow::{anyhow, bail};
2use arrow::datatypes::{DataType, Field, TimeUnit};
3use regex::Regex;
4use std::sync::Arc;
5use std::time::Duration;
6use syn::PathArguments::AngleBracketed;
7use syn::__private::ToTokens;
8use syn::{FnArg, GenericArgument, ItemFn, LitInt, LitStr, ReturnType, Type};
9
10#[derive(Clone, Debug, Eq, PartialEq)]
12pub struct NullableType {
13 pub data_type: DataType,
14 pub nullable: bool,
15}
16
17impl NullableType {
18 pub fn new(data_type: DataType, nullable: bool) -> Self {
19 Self {
20 data_type,
21 nullable,
22 }
23 }
24
25 pub fn null(data_type: DataType) -> Self {
26 Self {
27 data_type,
28 nullable: true,
29 }
30 }
31
32 pub fn not_null(data_type: DataType) -> Self {
33 Self {
34 data_type,
35 nullable: false,
36 }
37 }
38
39 pub fn with_nullability(&self, nullable: bool) -> Self {
40 Self {
41 data_type: self.data_type.clone(),
42 nullable,
43 }
44 }
45}
46
47pub fn is_vec_u8(typ: &Type) -> bool {
48 let Some(inner) = ParsedUdf::vec_inner_type(typ) else {
49 return false;
50 };
51
52 matches!(
53 rust_to_arrow(&inner, true),
54 Ok(NullableType {
55 data_type: DataType::UInt8,
56 nullable: false
57 })
58 )
59}
60
61pub(crate) fn rust_to_arrow(typ: &Type, expect_owned: bool) -> anyhow::Result<NullableType> {
62 match typ {
63 Type::Path(pat) => {
64 let last = pat.path.segments.last().unwrap();
65 if last.ident == "Option" {
66 let AngleBracketed(args) = &last.arguments else {
67 bail!("invalid Rust type; Option must have arguments");
68 };
69
70 let Some(GenericArgument::Type(inner)) = args.args.first() else {
71 bail!("invalid Rust type; Option must have an inner type parameter")
72 };
73
74 Ok(rust_to_arrow(inner, expect_owned)?.with_nullability(true))
75 } else {
76 let mut dt = rust_primitive_to_arrow(typ);
77
78 if dt.is_none() {
79 dt = Some(
80 match (
81 render_path(typ)
82 .ok_or_else(|| anyhow!("unsupported Rust type1"))?
83 .as_str(),
84 expect_owned,
85 ) {
86 ("String", true) => DataType::Utf8,
87 ("String", false) => {
88 bail!("expected reference type &str instead of String")
89 }
90 ("Vec<u8>", true) => DataType::Binary,
91 ("Vec<u8>", false) => {
92 bail!("expected reference type &[u8] instead of Vec<u8>")
93 }
94 (t, _) => bail!("unsupported Rust type {}", t),
95 },
96 );
97 }
98
99 Ok(NullableType::not_null(
100 dt.ok_or_else(|| anyhow!("unsupported Rust type2"))?,
101 ))
102 }
103 }
104 Type::Reference(r) => {
105 let t = render_path(&r.elem).ok_or_else(|| anyhow!("unsupported Rust type3"))?;
106
107 let dt = match (t.as_str(), rust_primitive_to_arrow(&r.elem), expect_owned) {
108 ("String", _, false) => bail!("expected &str, not &String"),
109 ("String", _, true) => {
110 bail!("expected owned String, not &String (hint: remove the &)")
111 }
112 ("Vec<u8>", _, false) => bail!("expected &[u8], not &Vec<u8>"),
113 ("Vec<u8>", _, true) => {
114 bail!("expected owned Vec<u8>, not &Vec<u8> (hint: remove the &)")
115 }
116 ("str", _, false) => DataType::Utf8,
117 ("str", _, true) => bail!("expected owned String, not &str"),
118 ("[u8]", _, false) => DataType::Binary,
119 ("[u8]", _, true) => bail!("expected owned Vec<u8>, not &[u8]"),
120 (t, Some(_), _) => bail!(
121 "unexpected &{}; primitives should be passed by value (hint: remove the &)",
122 t
123 ),
124 _ => {
125 bail!("unsupported Rust data type")
126 }
127 };
128
129 Ok(NullableType::not_null(dt))
130 }
131 _ => bail!("unsupported Rust data type"),
132 }
133}
134
135fn render_path(typ: &Type) -> Option<String> {
136 match typ {
137 Type::Path(pat) => {
138 let path: Vec<String> = pat
139 .path
140 .segments
141 .iter()
142 .map(|s| s.to_token_stream().to_string().replace(' ', ""))
143 .collect();
144
145 Some(path.join("::"))
146 }
147 Type::Slice(t) => Some(format!("[{}]", render_path(&t.elem)?)),
148 _ => None,
149 }
150}
151
152fn rust_primitive_to_arrow(typ: &Type) -> Option<DataType> {
153 match render_path(typ)?.as_str() {
154 "bool" => Some(DataType::Boolean),
155 "i8" => Some(DataType::Int8),
156 "i16" => Some(DataType::Int16),
157 "i32" => Some(DataType::Int32),
158 "i64" => Some(DataType::Int64),
159 "u8" => Some(DataType::UInt8),
160 "u16" => Some(DataType::UInt16),
161 "u32" => Some(DataType::UInt32),
162 "u64" => Some(DataType::UInt64),
163 "f16" => Some(DataType::Float16),
164 "f32" => Some(DataType::Float32),
165 "f64" => Some(DataType::Float64),
166 "SystemTime" | "std::time::SystemTime" => {
167 Some(DataType::Timestamp(TimeUnit::Microsecond, None))
168 }
169 "Duration" | "std::time::Duration" => Some(DataType::Duration(TimeUnit::Microsecond)),
170 _ => None,
171 }
172}
173
174#[derive(Clone, Debug)]
175pub struct UdfDef {
176 pub args: Vec<NullableType>,
177 pub ret: NullableType,
178 pub aggregate: bool,
179 pub udf_type: UdfType,
180}
181
/// Options attached to an async UDF.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AsyncOptions {
    /// Set by the `ordered` attribute, cleared by `unordered`.
    pub ordered: bool,
    /// Set by the `timeout` attribute.
    pub timeout: Duration,
    /// Set by the `allowed_in_flight` attribute.
    pub max_concurrency: usize,
}

impl Default for AsyncOptions {
    /// Unordered, a five second timeout and a concurrency limit of 1000.
    fn default() -> Self {
        const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
        const DEFAULT_MAX_CONCURRENCY: usize = 1000;

        AsyncOptions {
            ordered: false,
            timeout: DEFAULT_TIMEOUT,
            max_concurrency: DEFAULT_MAX_CONCURRENCY,
        }
    }
}
198
199#[derive(Copy, Clone, Debug, Eq, PartialEq)]
200pub enum UdfType {
201 Sync,
202 Async(AsyncOptions),
203}
204
205impl UdfType {
206 pub fn is_async(&self) -> bool {
207 !matches!(self, UdfType::Sync)
208 }
209}
210
211fn parse_duration(input: &str) -> anyhow::Result<Duration> {
212 let r = Regex::new(r"^(\d+)\s*([a-zA-Zµ]+)$").unwrap();
213 let captures = r
214 .captures(input)
215 .ok_or_else(|| anyhow!("invalid duration specification '{}'", input))?;
216 let mut capture = captures.iter();
217
218 capture.next();
219
220 let n: u64 = capture.next().unwrap().unwrap().as_str().parse().unwrap();
221 let unit = capture.next().unwrap().unwrap().as_str();
222
223 Ok(match unit {
224 "ns" | "nanos" => Duration::from_nanos(n),
225 "µs" | "micros" => Duration::from_micros(n),
226 "ms" | "millis" => Duration::from_millis(n),
227 "s" | "secs" | "seconds" => Duration::from_secs(n),
228 "m" | "mins" | "minutes" => Duration::from_secs(n * 60),
229 "h" | "hrs" | "hours" => Duration::from_secs(n * 60 * 60),
230 x => bail!("unknown time unit '{}'", x),
231 })
232}
233
234pub struct ParsedUdf {
235 pub function: String,
236 pub name: String,
237 pub args: Vec<NullableType>,
238 pub vec_arguments: usize,
239 pub ret_type: NullableType,
240 pub udf_type: UdfType,
241}
242
243impl ParsedUdf {
244 pub fn vec_inner_type(ty: &syn::Type) -> Option<syn::Type> {
245 if let syn::Type::Path(syn::TypePath { path, .. }) = ty {
246 if let Some(segment) = path.segments.last() {
247 if segment.ident == "Vec" {
248 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
249 if args.args.len() == 1 {
250 if let syn::GenericArgument::Type(inner_ty) = &args.args[0] {
251 return Some(inner_ty.clone());
252 }
253 }
254 }
255 }
256 }
257 }
258 None
259 }
260
261 pub fn try_parse(function: &ItemFn) -> anyhow::Result<ParsedUdf> {
262 let name = function.sig.ident.to_string();
263 let mut args = vec![];
264 let mut vec_arguments = 0;
265 for (i, arg) in function.sig.inputs.iter().enumerate() {
266 match arg {
267 FnArg::Receiver(_) => {
268 bail!(
269 "Function {} has a 'self' argument, which is not allowed",
270 name
271 )
272 }
273 FnArg::Typed(t) => {
274 let vec_type = Self::vec_inner_type(&t.ty);
275 if vec_type.is_some() {
276 vec_arguments += 1;
277 let vec_type = rust_to_arrow(vec_type.as_ref().unwrap(), false).map_err(|e| {
278 anyhow!(
279 "Could not convert function {name} inner vector arg {i} into an Arrow data type: {e}",
280 )
281 })?;
282
283 args.push(NullableType::not_null(DataType::List(Arc::new(
284 Field::new("item", vec_type.data_type, vec_type.nullable),
285 ))));
286 } else {
287 args.push(rust_to_arrow(&t.ty, false).map_err(|e| {
288 anyhow!(
289 "Could not convert function {name} arg {i} into a SQL data type: {e}",
290 )
291 })?);
292 }
293 }
294 }
295 }
296
297 let ret = match &function.sig.output {
298 ReturnType::Default => bail!("Function {} return type must be specified", name),
299 ReturnType::Type(_, t) => rust_to_arrow(t, true).map_err(|e| {
300 anyhow!("Could not convert function {name} return type into a SQL data type: {e}",)
301 })?,
302 };
303
304 let udf_type = if function.sig.asyncness.is_some() {
305 let mut t = AsyncOptions::default();
306
307 if let Some(attr) = function
308 .attrs
309 .iter()
310 .find(|attr| attr.path().is_ident("udf"))
311 {
312 if attr.meta.require_path_only().is_err() {
313 attr.parse_nested_meta(|meta| {
314 if meta.path.is_ident("ordered") {
315 t.ordered = true;
316 } else if meta.path.is_ident("unordered") {
317 t.ordered = false;
318 } else if meta.path.is_ident("allowed_in_flight") {
319 let value = meta.value()?;
320 let s: LitInt = value.parse()?;
321 let n: usize = s
322 .base10_digits()
323 .parse()
324 .map_err(|_| meta.error("expected number"))?;
325 t.max_concurrency = n;
326 } else if meta.path.is_ident("timeout") {
327 let value = meta.value()?;
328 let s: LitStr = value.parse()?;
329 t.timeout = parse_duration(&s.value()).map_err(|e| meta.error(e))?;
330 } else {
331 return Err(meta.error(format!(
332 "unsupported attribute '{}'",
333 meta.path.to_token_stream()
334 )));
335 }
336 Ok(())
337 })?;
338 }
339 }
340
341 UdfType::Async(t)
342 } else {
343 UdfType::Sync
344 };
345
346 Ok(ParsedUdf {
347 function: function.into_token_stream().to_string(),
348 name,
349 args,
350 vec_arguments,
351 ret_type: ret,
352 udf_type,
353 })
354 }
355}
356
357pub fn inner_type(dt: &DataType) -> Option<DataType> {
358 match dt {
359 DataType::List(f) => Some(f.data_type().clone()),
360 _ => None,
361 }
362}
363
#[cfg(test)]
mod tests {
    use crate::parse::{parse_duration, rust_to_arrow, NullableType};
    use arrow::datatypes::DataType;
    use std::time::Duration;
    use syn::parse_quote;

    /// Asserts that `ty` converts to `expected`.
    fn assert_maps(ty: syn::Type, expect_owned: bool, expected: NullableType) {
        assert_eq!(rust_to_arrow(&ty, expect_owned).unwrap(), expected);
    }

    /// Asserts that converting `ty` fails.
    fn assert_rejected(ty: syn::Type, expect_owned: bool) {
        assert!(rust_to_arrow(&ty, expect_owned).is_err());
    }

    #[test]
    fn test_duration() {
        let accepted = vec![
            ("5s", Duration::from_secs(5)),
            ("5 seconds", Duration::from_secs(5)),
            ("5 secs", Duration::from_secs(5)),
            ("10ms", Duration::from_millis(10)),
            ("110millis", Duration::from_millis(110)),
        ];
        for (spec, expected) in accepted {
            assert_eq!(expected, parse_duration(spec).unwrap());
        }

        for spec in vec!["-10ms", "10.0s", "5s what"] {
            assert!(parse_duration(spec).is_err());
        }
    }

    #[test]
    fn test_rust_to_arrow() {
        // primitives and their optional forms
        assert_maps(
            parse_quote!(i32),
            false,
            NullableType::not_null(DataType::Int32),
        );
        assert_maps(
            parse_quote!(Option<i32>),
            false,
            NullableType::null(DataType::Int32),
        );
        assert_maps(
            parse_quote!(u64),
            false,
            NullableType::not_null(DataType::UInt64),
        );
        assert_maps(
            parse_quote!(f32),
            false,
            NullableType::not_null(DataType::Float32),
        );
        assert_maps(
            parse_quote!(bool),
            false,
            NullableType::not_null(DataType::Boolean),
        );
        assert_maps(
            parse_quote!(Option<f64>),
            false,
            NullableType::null(DataType::Float64),
        );
        assert_maps(
            parse_quote!(Option<bool>),
            false,
            NullableType::null(DataType::Boolean),
        );

        // binary
        assert_maps(
            parse_quote!(Vec<u8>),
            true,
            NullableType::not_null(DataType::Binary),
        );
        assert_maps(
            parse_quote!(&[u8]),
            false,
            NullableType::not_null(DataType::Binary),
        );

        // strings
        assert_maps(
            parse_quote!(String),
            true,
            NullableType::not_null(DataType::Utf8),
        );
        assert_maps(
            parse_quote!(&str),
            false,
            NullableType::not_null(DataType::Utf8),
        );
        assert_maps(
            parse_quote!(Option<String>),
            true,
            NullableType::null(DataType::Utf8),
        );
        assert_maps(
            parse_quote!(Option<&str>),
            false,
            NullableType::null(DataType::Utf8),
        );

        // unsupported types
        assert_rejected(parse_quote!(HashMap<String, i32>), false);
        assert_rejected(parse_quote!(CustomStruct), false);

        // owned where borrowed is expected, and the reverse
        assert_rejected(parse_quote!(Vec<u8>), false);
        assert_rejected(parse_quote!(&[u8]), true);
    }
}