Skip to main content

helix_dsl_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    FnArg, GenericArgument, ItemFn, Pat, PathArguments, Type, parse_macro_input, parse_quote,
5};
6
/// Whether a `#[register]` function builds a read batch or a write batch.
///
/// Decided by `infer_query_kind` from the identifiers used in the function body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum QueryKind {
    /// The body calls `read_batch()` and never mentions `write_batch`.
    Read,
    /// The body mentions `write_batch` (takes precedence over `read_batch`).
    Write,
}
11
/// Schema shape of a `#[register]` parameter, derived from its declared Rust type.
#[derive(Clone, Debug, Eq, PartialEq)]
enum ParamTypeSpec {
    /// Declared as `bool`.
    Bool,
    /// Declared as `i64`.
    I64,
    /// Declared as `f64`.
    F64,
    /// Declared as `f32`.
    F32,
    /// Declared as `String`.
    String,
    /// Declared as `DateTime`.
    DateTime,
    /// Declared as `Vec<u8>`.
    Bytes,
    /// Declared as `PropertyValue` or `ParamValue`.
    Value,
    /// Declared as `ParamObject` or a `String`-keyed `BTreeMap`/`HashMap`; holds the value shape.
    Object(Box<Self>),
    /// Declared as `Vec<T>` for any supported `T` other than `u8`; holds the element shape.
    Array(Box<Self>),
}
25
26impl ParamTypeSpec {
27    fn to_tokens(&self) -> proc_macro2::TokenStream {
28        match self {
29            Self::Bool => quote! { ::helix_db::QueryParamType::Bool },
30            Self::I64 => quote! { ::helix_db::QueryParamType::I64 },
31            Self::F64 => quote! { ::helix_db::QueryParamType::F64 },
32            Self::F32 => quote! { ::helix_db::QueryParamType::F32 },
33            Self::String => quote! { ::helix_db::QueryParamType::String },
34            Self::DateTime => quote! { ::helix_db::QueryParamType::DateTime },
35            Self::Bytes => quote! { ::helix_db::QueryParamType::Bytes },
36            Self::Value => quote! { ::helix_db::QueryParamType::Value },
37            Self::Object(_) => quote! { ::helix_db::QueryParamType::Object },
38            Self::Array(inner) => {
39                let inner = inner.to_tokens();
40                quote! { ::helix_db::QueryParamType::Array(Box::new(#inner)) }
41            }
42        }
43    }
44
45    fn to_dynamic_value_tokens(
46        &self,
47        value: proc_macro2::TokenStream,
48        path: proc_macro2::TokenStream,
49        depth: usize,
50    ) -> proc_macro2::TokenStream {
51        match self {
52            Self::Bool => quote! {
53                ::std::result::Result::<
54                    ::helix_db::DynamicQueryValue,
55                    ::helix_db::DynamicQueryError,
56                >::Ok(::helix_db::DynamicQueryValue::Bool(#value))
57            },
58            Self::I64 => quote! {
59                ::std::result::Result::<
60                    ::helix_db::DynamicQueryValue,
61                    ::helix_db::DynamicQueryError,
62                >::Ok(::helix_db::DynamicQueryValue::I64(#value))
63            },
64            Self::F64 => quote! {
65                ::std::result::Result::<
66                    ::helix_db::DynamicQueryValue,
67                    ::helix_db::DynamicQueryError,
68                >::Ok(::helix_db::DynamicQueryValue::F64(#value))
69            },
70            Self::F32 => quote! {
71                ::std::result::Result::<
72                    ::helix_db::DynamicQueryValue,
73                    ::helix_db::DynamicQueryError,
74                >::Ok(::helix_db::DynamicQueryValue::F32(#value))
75            },
76            Self::String => quote! {
77                ::std::result::Result::<
78                    ::helix_db::DynamicQueryValue,
79                    ::helix_db::DynamicQueryError,
80                >::Ok(::helix_db::DynamicQueryValue::String(#value))
81            },
82            Self::DateTime => quote! {
83                ::std::result::Result::<
84                    ::helix_db::DynamicQueryValue,
85                    ::helix_db::DynamicQueryError,
86                >::Ok(::helix_db::DynamicQueryValue::String(
87                    (#value)
88                        .to_rfc3339()
89                        .ok_or_else(|| ::helix_db::DynamicQueryError::invalid_datetime(#path, (#value).millis()))?
90                ))
91            },
92            Self::Bytes => quote! {
93                ::std::result::Result::<
94                    ::helix_db::DynamicQueryValue,
95                    ::helix_db::DynamicQueryError,
96                >::Err(::helix_db::DynamicQueryError::unsupported_bytes(#path))
97            },
98            Self::Value => quote! {
99                ::helix_db::__private::dynamic_query_value_from_property_value(#value, #path)
100            },
101            Self::Object(inner) => {
102                let key_ident = format_ident!("__helix_param_key_{depth}");
103                let value_ident = format_ident!("__helix_param_value_{depth}");
104                let path_ident = format_ident!("__helix_param_path_{depth}");
105                let inner_tokens = inner.to_dynamic_value_tokens(
106                    quote! { #value_ident },
107                    quote! { #path_ident },
108                    depth + 1,
109                );
110
111                quote! {
112                    ::std::result::Result::<
113                        ::helix_db::DynamicQueryValue,
114                        ::helix_db::DynamicQueryError,
115                    >::Ok(::helix_db::DynamicQueryValue::Object(
116                        (#value)
117                            .into_iter()
118                            .map(|(#key_ident, #value_ident)| {
119                                let #path_ident = ::std::format!("{}.{}", #path, #key_ident);
120                                ::std::result::Result::Ok((#key_ident, #inner_tokens?))
121                            })
122                            .collect::<::std::result::Result<
123                                ::std::collections::BTreeMap<_, _>,
124                                ::helix_db::DynamicQueryError,
125                            >>()?,
126                    ))
127                }
128            }
129            Self::Array(inner) => {
130                let index_ident = format_ident!("__helix_param_index_{depth}");
131                let value_ident = format_ident!("__helix_param_value_{depth}");
132                let path_ident = format_ident!("__helix_param_path_{depth}");
133                let inner_tokens = inner.to_dynamic_value_tokens(
134                    quote! { #value_ident },
135                    quote! { #path_ident },
136                    depth + 1,
137                );
138
139                quote! {
140                    ::std::result::Result::<
141                        ::helix_db::DynamicQueryValue,
142                        ::helix_db::DynamicQueryError,
143                    >::Ok(::helix_db::DynamicQueryValue::Array(
144                        (#value)
145                            .into_iter()
146                            .enumerate()
147                            .map(|(#index_ident, #value_ident)| {
148                                let #path_ident = ::std::format!("{}[{}]", #path, #index_ident);
149                                #inner_tokens
150                            })
151                            .collect::<::std::result::Result<
152                                ::std::vec::Vec<_>,
153                                ::helix_db::DynamicQueryError,
154                            >>()?,
155                    ))
156                }
157            }
158        }
159    }
160}
161
162#[derive(Debug, Clone, PartialEq, Eq)]
163struct ParamSpec {
164    ident: syn::Ident,
165    ty: ParamTypeSpec,
166}
167
168/// Infer whether a registered function is a read or write query by inspecting its body for a
169/// call to `read_batch()` or `write_batch()`. The return type is not used (and may be omitted).
170fn infer_query_kind(fn_item: &ItemFn) -> syn::Result<QueryKind> {
171    fn mentions(tokens: proc_macro2::TokenStream, target: &str) -> bool {
172        tokens.into_iter().any(|tt| match tt {
173            proc_macro2::TokenTree::Ident(ident) => ident == target,
174            proc_macro2::TokenTree::Group(group) => mentions(group.stream(), target),
175            _ => false,
176        })
177    }
178
179    let body = &fn_item.block;
180    let tokens = quote! { #body };
181    if mentions(tokens.clone(), "write_batch") {
182        Ok(QueryKind::Write)
183    } else if mentions(tokens, "read_batch") {
184        Ok(QueryKind::Read)
185    } else {
186        Err(syn::Error::new_spanned(
187            &fn_item.sig,
188            "could not infer query kind: function body must call `read_batch()` or `write_batch()`",
189        ))
190    }
191}
192
/// Diagnostic attached to any `#[register]` parameter whose type is not supported.
const TYPE_ERROR_MSG: &str = concat!(
    "#[register] parameter type must be a supported query parameter type: ",
    "bool, i64, f64, f32, String, DateTime, Vec<u8>, PropertyValue, ParamValue, ParamObject, ",
    "Vec<T> for supported T, or BTreeMap<String, T>/HashMap<String, T> for supported T"
);
197
198fn ensure_no_args(segment: &syn::PathSegment, ty: &Type) -> syn::Result<()> {
199    if matches!(segment.arguments, PathArguments::None) {
200        Ok(())
201    } else {
202        Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG))
203    }
204}
205
206fn single_type_arg<'a>(segment: &'a syn::PathSegment, ty: &Type) -> syn::Result<&'a Type> {
207    let PathArguments::AngleBracketed(args) = &segment.arguments else {
208        return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG));
209    };
210    if args.args.len() != 1 {
211        return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG));
212    }
213    match args.args.first() {
214        Some(GenericArgument::Type(inner)) => Ok(inner),
215        _ => Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG)),
216    }
217}
218
219fn two_type_args<'a>(
220    segment: &'a syn::PathSegment,
221    ty: &Type,
222) -> syn::Result<(&'a Type, &'a Type)> {
223    let PathArguments::AngleBracketed(args) = &segment.arguments else {
224        return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG));
225    };
226    if args.args.len() != 2 {
227        return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG));
228    }
229    let first = match args.args.first() {
230        Some(GenericArgument::Type(inner)) => inner,
231        _ => return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG)),
232    };
233    let second = match args.args.iter().nth(1) {
234        Some(GenericArgument::Type(inner)) => inner,
235        _ => return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG)),
236    };
237    Ok((first, second))
238}
239
240fn is_string_type(ty: &Type) -> bool {
241    let Type::Path(type_path) = ty else {
242        return false;
243    };
244    let Some(segment) = type_path.path.segments.last() else {
245        return false;
246    };
247    segment.ident == "String" && matches!(segment.arguments, PathArguments::None)
248}
249
250/// Parse a supported registered parameter type into an owned schema shape.
251fn parse_param_type(ty: &Type) -> syn::Result<ParamTypeSpec> {
252    let Type::Path(type_path) = ty else {
253        return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG));
254    };
255
256    let Some(segment) = type_path.path.segments.last() else {
257        return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG));
258    };
259
260    let type_name = segment.ident.to_string();
261
262    match type_name.as_str() {
263        "bool" => {
264            ensure_no_args(segment, ty)?;
265            Ok(ParamTypeSpec::Bool)
266        }
267        "i64" => {
268            ensure_no_args(segment, ty)?;
269            Ok(ParamTypeSpec::I64)
270        }
271        "f64" => {
272            ensure_no_args(segment, ty)?;
273            Ok(ParamTypeSpec::F64)
274        }
275        "f32" => {
276            ensure_no_args(segment, ty)?;
277            Ok(ParamTypeSpec::F32)
278        }
279        "String" => {
280            ensure_no_args(segment, ty)?;
281            Ok(ParamTypeSpec::String)
282        }
283        "DateTime" => {
284            ensure_no_args(segment, ty)?;
285            Ok(ParamTypeSpec::DateTime)
286        }
287        "PropertyValue" | "ParamValue" => {
288            ensure_no_args(segment, ty)?;
289            Ok(ParamTypeSpec::Value)
290        }
291        "ParamObject" => {
292            ensure_no_args(segment, ty)?;
293            Ok(ParamTypeSpec::Object(Box::new(ParamTypeSpec::Value)))
294        }
295        "Vec" => {
296            let inner = single_type_arg(segment, ty)?;
297            if let Type::Path(inner_path) = inner {
298                if let Some(inner_seg) = inner_path.path.segments.last() {
299                    if inner_seg.ident == "u8" && matches!(inner_seg.arguments, PathArguments::None)
300                    {
301                        return Ok(ParamTypeSpec::Bytes);
302                    }
303                }
304            }
305            Ok(ParamTypeSpec::Array(Box::new(parse_param_type(inner)?)))
306        }
307        "BTreeMap" | "HashMap" => {
308            let (key_ty, value_ty) = two_type_args(segment, ty)?;
309            if !is_string_type(key_ty) {
310                return Err(syn::Error::new_spanned(key_ty, TYPE_ERROR_MSG));
311            }
312            Ok(ParamTypeSpec::Object(Box::new(parse_param_type(value_ty)?)))
313        }
314        _ => Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG)),
315    }
316}
317
318/// Extract and validate parameter declarations from function arguments.
319fn extract_param_specs(fn_item: &ItemFn) -> syn::Result<Vec<ParamSpec>> {
320    let mut params = Vec::new();
321    for arg in &fn_item.sig.inputs {
322        match arg {
323            FnArg::Receiver(recv) => {
324                return Err(syn::Error::new_spanned(
325                    recv,
326                    "#[register] functions cannot take self",
327                ));
328            }
329            FnArg::Typed(pat_type) => {
330                if let Pat::Ident(pat_ident) = &*pat_type.pat {
331                    params.push(ParamSpec {
332                        ident: pat_ident.ident.clone(),
333                        ty: parse_param_type(&pat_type.ty)?,
334                    });
335                } else {
336                    return Err(syn::Error::new_spanned(
337                        &pat_type.pat,
338                        "#[register] function parameters must be simple identifiers",
339                    ));
340                }
341            }
342        }
343    }
344    Ok(params)
345}
346
347#[proc_macro_attribute]
348pub fn register(attr: TokenStream, item: TokenStream) -> TokenStream {
349    if !attr.is_empty() {
350        return syn::Error::new(
351            proc_macro2::Span::call_site(),
352            "#[register] does not accept arguments",
353        )
354        .to_compile_error()
355        .into();
356    }
357
358    let fn_item = parse_macro_input!(item as ItemFn);
359
360    if fn_item.sig.asyncness.is_some() {
361        return syn::Error::new_spanned(&fn_item.sig, "#[register] functions cannot be async")
362            .to_compile_error()
363            .into();
364    }
365
366    if !fn_item.sig.generics.params.is_empty() {
367        return syn::Error::new_spanned(
368            &fn_item.sig.generics,
369            "#[register] functions cannot be generic",
370        )
371        .to_compile_error()
372        .into();
373    }
374
375    let query_kind = match infer_query_kind(&fn_item) {
376        Ok(kind) => kind,
377        Err(err) => return err.to_compile_error().into(),
378    };
379
380    let param_specs = match extract_param_specs(&fn_item) {
381        Ok(params) => params,
382        Err(err) => return err.to_compile_error().into(),
383    };
384
385    let fn_name = fn_item.sig.ident.clone();
386    let fn_attrs = fn_item.attrs.clone();
387    let fn_visibility = fn_item.vis.clone();
388    let fn_body = &fn_item.block;
389    let params_fn_name = format_ident!("__helix_dsl_params_{}", fn_name);
390
391    // Generate `let name = Expr::param("name");` bindings for each parameter
392    let param_name_strs: Vec<String> = param_specs
393        .iter()
394        .map(|param| param.ident.to_string())
395        .collect();
396    let let_bindings = param_specs
397        .iter()
398        .zip(param_name_strs.iter())
399        .map(|(param, name_str)| {
400            let ident = &param.ident;
401            quote! {
402                let #ident = ::helix_db::Expr::param(#name_str);
403            }
404        });
405
406    let parameter_entries = param_specs
407        .iter()
408        .zip(param_name_strs.iter())
409        .map(|(param, name)| {
410            let ty = param.ty.to_tokens();
411            quote! {
412                ::helix_db::QueryParameter {
413                    name: #name.to_string(),
414                    ty: #ty,
415                }
416            }
417        });
418
419    let parameters_fn = quote! {
420        #[allow(non_snake_case)]
421        fn #params_fn_name() -> ::std::vec::Vec<::helix_db::QueryParameter> {
422            vec![#(#parameter_entries),*]
423        }
424    };
425
426    let decomposed_fn_name = format_ident!("{}_decomposed", fn_name);
427
428    // Decompose: strip params, prepend let bindings to body
429    let decomposed_fn = match query_kind {
430        QueryKind::Read => quote! {
431            fn #decomposed_fn_name() -> ::helix_db::ReadBatch {
432                #(#let_bindings)*
433                #fn_body
434            }
435        },
436        QueryKind::Write => quote! {
437            fn #decomposed_fn_name() -> ::helix_db::WriteBatch {
438                #(#let_bindings)*
439                #fn_body
440            }
441        },
442    };
443
444    // The user-callable function always returns a bare `DynamicQueryRequest`, regardless of
445    // visibility: calling `query1(args)` builds the request directly. Parameter coercion that can
446    // fail (DateTime, bytes, Value, collections) is wrapped in a closure and `.expect()`-ed so the
447    // `?` inside `to_dynamic_value_tokens` has a `Result` context; infallible scalars never panic.
448    let callable_fn = {
449        let mut request_sig = fn_item.sig.clone();
450        request_sig.output = parse_quote!(-> ::helix_db::DynamicQueryRequest);
451        let request_ctor = match query_kind {
452            QueryKind::Read => quote! { ::helix_db::DynamicQueryRequest::read },
453            QueryKind::Write => quote! { ::helix_db::DynamicQueryRequest::write },
454        };
455        let request_param_inserts =
456            param_specs
457                .iter()
458                .zip(param_name_strs.iter())
459                .map(|(param, name)| {
460                    let ident = &param.ident;
461                    let value_tokens =
462                        param
463                            .ty
464                            .to_dynamic_value_tokens(quote! { #ident }, quote! { #name }, 0);
465                    let type_tokens = param.ty.to_tokens();
466                    let expect_msg = format!("failed to coerce parameter `{name}`");
467                    quote! {
468                        request.insert_parameter_value(
469                            #name,
470                            (|| -> ::std::result::Result<
471                                ::helix_db::DynamicQueryValue,
472                                ::helix_db::DynamicQueryError,
473                            > { #value_tokens })()
474                            .expect(#expect_msg),
475                        );
476                        request.insert_parameter_type(#name, #type_tokens);
477                    }
478                });
479
480        quote! {
481            #(#fn_attrs)*
482            #fn_visibility #request_sig {
483                let mut request = #request_ctor(#decomposed_fn_name());
484                #(#request_param_inserts)*
485                request
486            }
487        }
488    };
489
490    let submit_item = match query_kind {
491        QueryKind::Read => {
492            quote! {
493                ::helix_db::__private::inventory::submit! {
494                    ::helix_db::RegisteredReadQuery {
495                        name: stringify!(#fn_name),
496                        build: #decomposed_fn_name,
497                        parameters: #params_fn_name,
498                    }
499                }
500            }
501        }
502        QueryKind::Write => {
503            quote! {
504                ::helix_db::__private::inventory::submit! {
505                    ::helix_db::RegisteredWriteQuery {
506                        name: stringify!(#fn_name),
507                        build: #decomposed_fn_name,
508                        parameters: #params_fn_name,
509                    }
510                }
511            }
512        }
513    };
514
515    quote! {
516        #callable_fn
517        #decomposed_fn
518        #parameters_fn
519        #submit_item
520    }
521    .into()
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use syn::{Type, parse_str};

    /// Parse `input` as a Rust type and classify it, panicking if it is rejected.
    fn parse_type(input: &str) -> ParamTypeSpec {
        let ty: Type = parse_str(input).expect("parse type");
        parse_param_type(&ty).expect("supported param type")
    }

    /// Assert that `input` is valid Rust type syntax but is rejected as a parameter type.
    fn assert_rejected(input: &str) {
        let ty: Type = parse_str(input).expect("parse type");
        assert!(
            parse_param_type(&ty).is_err(),
            "expected `{input}` to be rejected"
        );
    }

    #[test]
    fn accepts_nested_batch_object_types() {
        assert_eq!(
            parse_type("ParamObject"),
            ParamTypeSpec::Object(Box::new(ParamTypeSpec::Value))
        );
        assert_eq!(
            parse_type("Vec<ParamObject>"),
            ParamTypeSpec::Array(Box::new(ParamTypeSpec::Object(Box::new(
                ParamTypeSpec::Value
            ))))
        );
        assert_eq!(
            parse_type("Vec<Vec<ParamObject>>"),
            ParamTypeSpec::Array(Box::new(ParamTypeSpec::Array(Box::new(
                ParamTypeSpec::Object(Box::new(ParamTypeSpec::Value))
            ))))
        );
    }

    #[test]
    fn accepts_property_value_aliases_and_maps() {
        assert_eq!(parse_type("PropertyValue"), ParamTypeSpec::Value);
        assert_eq!(parse_type("ParamValue"), ParamTypeSpec::Value);
        assert_eq!(
            parse_type("BTreeMap<String, PropertyValue>"),
            ParamTypeSpec::Object(Box::new(ParamTypeSpec::Value))
        );
        assert_eq!(
            parse_type("std::collections::HashMap<String, ParamValue>"),
            ParamTypeSpec::Object(Box::new(ParamTypeSpec::Value))
        );
        assert_eq!(
            parse_type("BTreeMap<String, String>"),
            ParamTypeSpec::Object(Box::new(ParamTypeSpec::String))
        );
    }

    #[test]
    fn accepts_existing_scalar_and_array_types() {
        assert_eq!(parse_type("bool"), ParamTypeSpec::Bool);
        assert_eq!(parse_type("i64"), ParamTypeSpec::I64);
        assert_eq!(parse_type("DateTime"), ParamTypeSpec::DateTime);
        assert_eq!(parse_type("Vec<u8>"), ParamTypeSpec::Bytes);
        assert_eq!(
            parse_type("Vec<String>"),
            ParamTypeSpec::Array(Box::new(ParamTypeSpec::String))
        );
    }

    #[test]
    fn accepts_remaining_scalars_and_qualified_paths() {
        assert_eq!(parse_type("f64"), ParamTypeSpec::F64);
        assert_eq!(parse_type("f32"), ParamTypeSpec::F32);
        assert_eq!(parse_type("String"), ParamTypeSpec::String);
        assert_eq!(parse_type("std::string::String"), ParamTypeSpec::String);
        assert_eq!(parse_type("Vec<std::primitive::u8>"), ParamTypeSpec::Bytes);
        assert_eq!(
            parse_type("HashMap<std::string::String, Vec<DateTime>>"),
            ParamTypeSpec::Object(Box::new(ParamTypeSpec::Array(Box::new(
                ParamTypeSpec::DateTime
            ))))
        );
    }

    #[test]
    fn rejects_unsupported_types() {
        assert_rejected("UserBatchRow");
        assert_rejected("Vec<UserBatchRow>");
        assert_rejected("BTreeMap<i64, String>");
    }

    #[test]
    fn rejects_malformed_arguments_and_non_path_types() {
        // Scalars must not carry generic arguments.
        assert_rejected("i64<u8>");
        // Wrong arity or non-type generic arguments.
        assert_rejected("Vec<String, Global>");
        assert_rejected("Vec<'static>");
        assert_rejected("BTreeMap<String>");
        // Unsupported wrappers and non-path types.
        assert_rejected("Option<i64>");
        assert_rejected("&str");
        assert_rejected("(i64, bool)");
    }
}