1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 FnArg, GenericArgument, ItemFn, Pat, PathArguments, Type, parse_macro_input, parse_quote,
5};
6
/// Whether a registered query function builds a read batch or a write batch.
///
/// Inferred from the function body by `infer_query_kind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum QueryKind {
    /// The body mentions `read_batch` (and not `write_batch`).
    Read,
    /// The body mentions `write_batch`.
    Write,
}
11
12#[derive(Debug, Clone, PartialEq, Eq)]
13enum ParamTypeSpec {
14 Bool,
15 I64,
16 F64,
17 F32,
18 String,
19 DateTime,
20 Bytes,
21 Value,
22 Object(Box<ParamTypeSpec>),
23 Array(Box<ParamTypeSpec>),
24}
25
26impl ParamTypeSpec {
27 fn to_tokens(&self) -> proc_macro2::TokenStream {
28 match self {
29 Self::Bool => quote! { ::helix_db::QueryParamType::Bool },
30 Self::I64 => quote! { ::helix_db::QueryParamType::I64 },
31 Self::F64 => quote! { ::helix_db::QueryParamType::F64 },
32 Self::F32 => quote! { ::helix_db::QueryParamType::F32 },
33 Self::String => quote! { ::helix_db::QueryParamType::String },
34 Self::DateTime => quote! { ::helix_db::QueryParamType::DateTime },
35 Self::Bytes => quote! { ::helix_db::QueryParamType::Bytes },
36 Self::Value => quote! { ::helix_db::QueryParamType::Value },
37 Self::Object(_) => quote! { ::helix_db::QueryParamType::Object },
38 Self::Array(inner) => {
39 let inner = inner.to_tokens();
40 quote! { ::helix_db::QueryParamType::Array(Box::new(#inner)) }
41 }
42 }
43 }
44
45 fn to_dynamic_value_tokens(
46 &self,
47 value: proc_macro2::TokenStream,
48 path: proc_macro2::TokenStream,
49 depth: usize,
50 ) -> proc_macro2::TokenStream {
51 match self {
52 Self::Bool => quote! {
53 ::std::result::Result::<
54 ::helix_db::DynamicQueryValue,
55 ::helix_db::DynamicQueryError,
56 >::Ok(::helix_db::DynamicQueryValue::Bool(#value))
57 },
58 Self::I64 => quote! {
59 ::std::result::Result::<
60 ::helix_db::DynamicQueryValue,
61 ::helix_db::DynamicQueryError,
62 >::Ok(::helix_db::DynamicQueryValue::I64(#value))
63 },
64 Self::F64 => quote! {
65 ::std::result::Result::<
66 ::helix_db::DynamicQueryValue,
67 ::helix_db::DynamicQueryError,
68 >::Ok(::helix_db::DynamicQueryValue::F64(#value))
69 },
70 Self::F32 => quote! {
71 ::std::result::Result::<
72 ::helix_db::DynamicQueryValue,
73 ::helix_db::DynamicQueryError,
74 >::Ok(::helix_db::DynamicQueryValue::F32(#value))
75 },
76 Self::String => quote! {
77 ::std::result::Result::<
78 ::helix_db::DynamicQueryValue,
79 ::helix_db::DynamicQueryError,
80 >::Ok(::helix_db::DynamicQueryValue::String(#value))
81 },
82 Self::DateTime => quote! {
83 ::std::result::Result::<
84 ::helix_db::DynamicQueryValue,
85 ::helix_db::DynamicQueryError,
86 >::Ok(::helix_db::DynamicQueryValue::String(
87 (#value)
88 .to_rfc3339()
89 .ok_or_else(|| ::helix_db::DynamicQueryError::invalid_datetime(#path, (#value).millis()))?
90 ))
91 },
92 Self::Bytes => quote! {
93 ::std::result::Result::<
94 ::helix_db::DynamicQueryValue,
95 ::helix_db::DynamicQueryError,
96 >::Err(::helix_db::DynamicQueryError::unsupported_bytes(#path))
97 },
98 Self::Value => quote! {
99 ::helix_db::__private::dynamic_query_value_from_property_value(#value, #path)
100 },
101 Self::Object(inner) => {
102 let key_ident = format_ident!("__helix_param_key_{depth}");
103 let value_ident = format_ident!("__helix_param_value_{depth}");
104 let path_ident = format_ident!("__helix_param_path_{depth}");
105 let inner_tokens = inner.to_dynamic_value_tokens(
106 quote! { #value_ident },
107 quote! { #path_ident },
108 depth + 1,
109 );
110
111 quote! {
112 ::std::result::Result::<
113 ::helix_db::DynamicQueryValue,
114 ::helix_db::DynamicQueryError,
115 >::Ok(::helix_db::DynamicQueryValue::Object(
116 (#value)
117 .into_iter()
118 .map(|(#key_ident, #value_ident)| {
119 let #path_ident = ::std::format!("{}.{}", #path, #key_ident);
120 ::std::result::Result::Ok((#key_ident, #inner_tokens?))
121 })
122 .collect::<::std::result::Result<
123 ::std::collections::BTreeMap<_, _>,
124 ::helix_db::DynamicQueryError,
125 >>()?,
126 ))
127 }
128 }
129 Self::Array(inner) => {
130 let index_ident = format_ident!("__helix_param_index_{depth}");
131 let value_ident = format_ident!("__helix_param_value_{depth}");
132 let path_ident = format_ident!("__helix_param_path_{depth}");
133 let inner_tokens = inner.to_dynamic_value_tokens(
134 quote! { #value_ident },
135 quote! { #path_ident },
136 depth + 1,
137 );
138
139 quote! {
140 ::std::result::Result::<
141 ::helix_db::DynamicQueryValue,
142 ::helix_db::DynamicQueryError,
143 >::Ok(::helix_db::DynamicQueryValue::Array(
144 (#value)
145 .into_iter()
146 .enumerate()
147 .map(|(#index_ident, #value_ident)| {
148 let #path_ident = ::std::format!("{}[{}]", #path, #index_ident);
149 #inner_tokens
150 })
151 .collect::<::std::result::Result<
152 ::std::vec::Vec<_>,
153 ::helix_db::DynamicQueryError,
154 >>()?,
155 ))
156 }
157 }
158 }
159 }
160}
161
/// One parameter of a `#[register]` function: its binding name and parsed type.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ParamSpec {
    /// Identifier the parameter is bound to in the function signature.
    ident: syn::Ident,
    /// Query parameter type parsed from the declared Rust type.
    ty: ParamTypeSpec,
}
167
168fn infer_query_kind(fn_item: &ItemFn) -> syn::Result<QueryKind> {
171 fn mentions(tokens: proc_macro2::TokenStream, target: &str) -> bool {
172 tokens.into_iter().any(|tt| match tt {
173 proc_macro2::TokenTree::Ident(ident) => ident == target,
174 proc_macro2::TokenTree::Group(group) => mentions(group.stream(), target),
175 _ => false,
176 })
177 }
178
179 let body = &fn_item.block;
180 let tokens = quote! { #body };
181 if mentions(tokens.clone(), "write_batch") {
182 Ok(QueryKind::Write)
183 } else if mentions(tokens, "read_batch") {
184 Ok(QueryKind::Read)
185 } else {
186 Err(syn::Error::new_spanned(
187 &fn_item.sig,
188 "could not infer query kind: function body must call `read_batch()` or `write_batch()`",
189 ))
190 }
191}
192
/// Diagnostic used for every rejected `#[register]` parameter type; it lists
/// the accepted type spellings.
const TYPE_ERROR_MSG: &str = "\
#[register] parameter type must be a supported query parameter type: \
bool, i64, f64, f32, String, DateTime, Vec<u8>, PropertyValue, ParamValue, ParamObject, \
Vec<T> for supported T, or BTreeMap<String, T>/HashMap<String, T> for supported T";
197
198fn ensure_no_args(segment: &syn::PathSegment, ty: &Type) -> syn::Result<()> {
199 if matches!(segment.arguments, PathArguments::None) {
200 Ok(())
201 } else {
202 Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG))
203 }
204}
205
206fn single_type_arg<'a>(segment: &'a syn::PathSegment, ty: &Type) -> syn::Result<&'a Type> {
207 let PathArguments::AngleBracketed(args) = &segment.arguments else {
208 return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG));
209 };
210 if args.args.len() != 1 {
211 return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG));
212 }
213 match args.args.first() {
214 Some(GenericArgument::Type(inner)) => Ok(inner),
215 _ => Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG)),
216 }
217}
218
219fn two_type_args<'a>(
220 segment: &'a syn::PathSegment,
221 ty: &Type,
222) -> syn::Result<(&'a Type, &'a Type)> {
223 let PathArguments::AngleBracketed(args) = &segment.arguments else {
224 return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG));
225 };
226 if args.args.len() != 2 {
227 return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG));
228 }
229 let first = match args.args.first() {
230 Some(GenericArgument::Type(inner)) => inner,
231 _ => return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG)),
232 };
233 let second = match args.args.iter().nth(1) {
234 Some(GenericArgument::Type(inner)) => inner,
235 _ => return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG)),
236 };
237 Ok((first, second))
238}
239
240fn is_string_type(ty: &Type) -> bool {
241 let Type::Path(type_path) = ty else {
242 return false;
243 };
244 let Some(segment) = type_path.path.segments.last() else {
245 return false;
246 };
247 segment.ident == "String" && matches!(segment.arguments, PathArguments::None)
248}
249
250fn parse_param_type(ty: &Type) -> syn::Result<ParamTypeSpec> {
252 let Type::Path(type_path) = ty else {
253 return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG));
254 };
255
256 let Some(segment) = type_path.path.segments.last() else {
257 return Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG));
258 };
259
260 let type_name = segment.ident.to_string();
261
262 match type_name.as_str() {
263 "bool" => {
264 ensure_no_args(segment, ty)?;
265 Ok(ParamTypeSpec::Bool)
266 }
267 "i64" => {
268 ensure_no_args(segment, ty)?;
269 Ok(ParamTypeSpec::I64)
270 }
271 "f64" => {
272 ensure_no_args(segment, ty)?;
273 Ok(ParamTypeSpec::F64)
274 }
275 "f32" => {
276 ensure_no_args(segment, ty)?;
277 Ok(ParamTypeSpec::F32)
278 }
279 "String" => {
280 ensure_no_args(segment, ty)?;
281 Ok(ParamTypeSpec::String)
282 }
283 "DateTime" => {
284 ensure_no_args(segment, ty)?;
285 Ok(ParamTypeSpec::DateTime)
286 }
287 "PropertyValue" | "ParamValue" => {
288 ensure_no_args(segment, ty)?;
289 Ok(ParamTypeSpec::Value)
290 }
291 "ParamObject" => {
292 ensure_no_args(segment, ty)?;
293 Ok(ParamTypeSpec::Object(Box::new(ParamTypeSpec::Value)))
294 }
295 "Vec" => {
296 let inner = single_type_arg(segment, ty)?;
297 if let Type::Path(inner_path) = inner {
298 if let Some(inner_seg) = inner_path.path.segments.last() {
299 if inner_seg.ident == "u8" && matches!(inner_seg.arguments, PathArguments::None)
300 {
301 return Ok(ParamTypeSpec::Bytes);
302 }
303 }
304 }
305 Ok(ParamTypeSpec::Array(Box::new(parse_param_type(inner)?)))
306 }
307 "BTreeMap" | "HashMap" => {
308 let (key_ty, value_ty) = two_type_args(segment, ty)?;
309 if !is_string_type(key_ty) {
310 return Err(syn::Error::new_spanned(key_ty, TYPE_ERROR_MSG));
311 }
312 Ok(ParamTypeSpec::Object(Box::new(parse_param_type(value_ty)?)))
313 }
314 _ => Err(syn::Error::new_spanned(ty, TYPE_ERROR_MSG)),
315 }
316}
317
318fn extract_param_specs(fn_item: &ItemFn) -> syn::Result<Vec<ParamSpec>> {
320 let mut params = Vec::new();
321 for arg in &fn_item.sig.inputs {
322 match arg {
323 FnArg::Receiver(recv) => {
324 return Err(syn::Error::new_spanned(
325 recv,
326 "#[register] functions cannot take self",
327 ));
328 }
329 FnArg::Typed(pat_type) => {
330 if let Pat::Ident(pat_ident) = &*pat_type.pat {
331 params.push(ParamSpec {
332 ident: pat_ident.ident.clone(),
333 ty: parse_param_type(&pat_type.ty)?,
334 });
335 } else {
336 return Err(syn::Error::new_spanned(
337 &pat_type.pat,
338 "#[register] function parameters must be simple identifiers",
339 ));
340 }
341 }
342 }
343 }
344 Ok(params)
345}
346
347#[proc_macro_attribute]
348pub fn register(attr: TokenStream, item: TokenStream) -> TokenStream {
349 if !attr.is_empty() {
350 return syn::Error::new(
351 proc_macro2::Span::call_site(),
352 "#[register] does not accept arguments",
353 )
354 .to_compile_error()
355 .into();
356 }
357
358 let fn_item = parse_macro_input!(item as ItemFn);
359
360 if fn_item.sig.asyncness.is_some() {
361 return syn::Error::new_spanned(&fn_item.sig, "#[register] functions cannot be async")
362 .to_compile_error()
363 .into();
364 }
365
366 if !fn_item.sig.generics.params.is_empty() {
367 return syn::Error::new_spanned(
368 &fn_item.sig.generics,
369 "#[register] functions cannot be generic",
370 )
371 .to_compile_error()
372 .into();
373 }
374
375 let query_kind = match infer_query_kind(&fn_item) {
376 Ok(kind) => kind,
377 Err(err) => return err.to_compile_error().into(),
378 };
379
380 let param_specs = match extract_param_specs(&fn_item) {
381 Ok(params) => params,
382 Err(err) => return err.to_compile_error().into(),
383 };
384
385 let fn_name = fn_item.sig.ident.clone();
386 let fn_attrs = fn_item.attrs.clone();
387 let fn_visibility = fn_item.vis.clone();
388 let fn_body = &fn_item.block;
389 let params_fn_name = format_ident!("__helix_dsl_params_{}", fn_name);
390
391 let param_name_strs: Vec<String> = param_specs
393 .iter()
394 .map(|param| param.ident.to_string())
395 .collect();
396 let let_bindings = param_specs
397 .iter()
398 .zip(param_name_strs.iter())
399 .map(|(param, name_str)| {
400 let ident = ¶m.ident;
401 quote! {
402 let #ident = ::helix_db::Expr::param(#name_str);
403 }
404 });
405
406 let parameter_entries = param_specs
407 .iter()
408 .zip(param_name_strs.iter())
409 .map(|(param, name)| {
410 let ty = param.ty.to_tokens();
411 quote! {
412 ::helix_db::QueryParameter {
413 name: #name.to_string(),
414 ty: #ty,
415 }
416 }
417 });
418
419 let parameters_fn = quote! {
420 #[allow(non_snake_case)]
421 fn #params_fn_name() -> ::std::vec::Vec<::helix_db::QueryParameter> {
422 vec![#(#parameter_entries),*]
423 }
424 };
425
426 let decomposed_fn_name = format_ident!("{}_decomposed", fn_name);
427
428 let decomposed_fn = match query_kind {
430 QueryKind::Read => quote! {
431 fn #decomposed_fn_name() -> ::helix_db::ReadBatch {
432 #(#let_bindings)*
433 #fn_body
434 }
435 },
436 QueryKind::Write => quote! {
437 fn #decomposed_fn_name() -> ::helix_db::WriteBatch {
438 #(#let_bindings)*
439 #fn_body
440 }
441 },
442 };
443
444 let callable_fn = {
449 let mut request_sig = fn_item.sig.clone();
450 request_sig.output = parse_quote!(-> ::helix_db::DynamicQueryRequest);
451 let request_ctor = match query_kind {
452 QueryKind::Read => quote! { ::helix_db::DynamicQueryRequest::read },
453 QueryKind::Write => quote! { ::helix_db::DynamicQueryRequest::write },
454 };
455 let request_param_inserts =
456 param_specs
457 .iter()
458 .zip(param_name_strs.iter())
459 .map(|(param, name)| {
460 let ident = ¶m.ident;
461 let value_tokens =
462 param
463 .ty
464 .to_dynamic_value_tokens(quote! { #ident }, quote! { #name }, 0);
465 let type_tokens = param.ty.to_tokens();
466 let expect_msg = format!("failed to coerce parameter `{name}`");
467 quote! {
468 request.insert_parameter_value(
469 #name,
470 (|| -> ::std::result::Result<
471 ::helix_db::DynamicQueryValue,
472 ::helix_db::DynamicQueryError,
473 > { #value_tokens })()
474 .expect(#expect_msg),
475 );
476 request.insert_parameter_type(#name, #type_tokens);
477 }
478 });
479
480 quote! {
481 #(#fn_attrs)*
482 #fn_visibility #request_sig {
483 let mut request = #request_ctor(#decomposed_fn_name());
484 #(#request_param_inserts)*
485 request
486 }
487 }
488 };
489
490 let submit_item = match query_kind {
491 QueryKind::Read => {
492 quote! {
493 ::helix_db::__private::inventory::submit! {
494 ::helix_db::RegisteredReadQuery {
495 name: stringify!(#fn_name),
496 build: #decomposed_fn_name,
497 parameters: #params_fn_name,
498 }
499 }
500 }
501 }
502 QueryKind::Write => {
503 quote! {
504 ::helix_db::__private::inventory::submit! {
505 ::helix_db::RegisteredWriteQuery {
506 name: stringify!(#fn_name),
507 build: #decomposed_fn_name,
508 parameters: #params_fn_name,
509 }
510 }
511 }
512 }
513 };
514
515 quote! {
516 #callable_fn
517 #decomposed_fn
518 #parameters_fn
519 #submit_item
520 }
521 .into()
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use syn::{Type, parse_str};

    /// Parses `input` as a Rust type and converts it, panicking on any failure.
    fn parse_type(input: &str) -> ParamTypeSpec {
        let ty: Type = parse_str(input).expect("parse type");
        parse_param_type(&ty).expect("supported param type")
    }

    /// Shorthand for `ParamTypeSpec::Object` around `inner`.
    fn object_of(inner: ParamTypeSpec) -> ParamTypeSpec {
        ParamTypeSpec::Object(Box::new(inner))
    }

    /// Shorthand for `ParamTypeSpec::Array` around `inner`.
    fn array_of(inner: ParamTypeSpec) -> ParamTypeSpec {
        ParamTypeSpec::Array(Box::new(inner))
    }

    /// Asserts that every `(input, expected)` pair parses to the expected spec.
    fn assert_parses(cases: Vec<(&str, ParamTypeSpec)>) {
        for (input, expected) in cases {
            assert_eq!(parse_type(input), expected);
        }
    }

    #[test]
    fn accepts_nested_batch_object_types() {
        assert_parses(vec![
            ("ParamObject", object_of(ParamTypeSpec::Value)),
            ("Vec<ParamObject>", array_of(object_of(ParamTypeSpec::Value))),
            (
                "Vec<Vec<ParamObject>>",
                array_of(array_of(object_of(ParamTypeSpec::Value))),
            ),
        ]);
    }

    #[test]
    fn accepts_property_value_aliases_and_maps() {
        assert_parses(vec![
            ("PropertyValue", ParamTypeSpec::Value),
            ("ParamValue", ParamTypeSpec::Value),
            (
                "BTreeMap<String, PropertyValue>",
                object_of(ParamTypeSpec::Value),
            ),
            (
                "std::collections::HashMap<String, ParamValue>",
                object_of(ParamTypeSpec::Value),
            ),
            ("BTreeMap<String, String>", object_of(ParamTypeSpec::String)),
        ]);
    }

    #[test]
    fn accepts_existing_scalar_and_array_types() {
        assert_parses(vec![
            ("bool", ParamTypeSpec::Bool),
            ("i64", ParamTypeSpec::I64),
            ("DateTime", ParamTypeSpec::DateTime),
            ("Vec<u8>", ParamTypeSpec::Bytes),
            ("Vec<String>", array_of(ParamTypeSpec::String)),
        ]);
    }

    #[test]
    fn rejects_unsupported_types() {
        for input in ["UserBatchRow", "Vec<UserBatchRow>", "BTreeMap<i64, String>"] {
            let ty: Type = parse_str(input).expect("parse type");
            assert!(parse_param_type(&ty).is_err());
        }
    }
}