surrealqlx_macros_impl/
lib.rs

1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote};
3use syn::{Data, DeriveInput, ExprAssign, ExprLit, parse::Parse, punctuated::Punctuated};
4
5#[cfg(test)]
6mod tests;
7
8/// Implementation of the Table derive macro
9///
10/// # Errors
11///
12/// This function will return an error if the input couldn't be parsed, or if attributes are missing or invalid.
13pub fn table_macro_impl(input: TokenStream) -> syn::Result<TokenStream> {
14    let input = syn::parse2::<DeriveInput>(input)?;
15
16    let struct_name = &input.ident;
17
18    let table_name = parse_table_name(&input)?;
19
20    let struct_fields = parse_struct_fields(&input)?;
21
22    let (table_field_queries, index_queries) = parse_attributes(struct_fields, &table_name)?;
23
24    let table_query = format!("DEFINE TABLE {table_name} SCHEMAFULL;");
25
26    let table_field_queries = table_field_queries.iter().map(|q| quote! {.query(#q)});
27    let index_queries = index_queries.iter().map(|q| quote! {.query(#q)});
28
29    // Build the output, possibly using the input
30    let expanded = quote! {
31        // The generated impl goes here
32        impl ::surrealqlx::traits::Table for #struct_name {
33            const TABLE_NAME: &'static str = #table_name;
34
35            #[allow(manual_async_fn)]
36            fn init_table<C: ::surrealdb::Connection>(
37                db: &::surrealdb::Surreal<C>,
38            ) -> impl ::std::future::Future<Output = ::surrealdb::Result<()>> + Send {
39                async {
40                    let _ = db.query("BEGIN;")
41                        .query(#table_query)
42                        .query("COMMIT;")
43                        .query("BEGIN;")
44                        #(
45                            #table_field_queries
46                        )*
47                        .query("COMMIT;")
48                        .query("BEGIN;")
49                        #(
50                            #index_queries
51                        )*
52                        .query("COMMIT;").await?;
53                    Ok(())
54                }
55            }
56        }
57    };
58
59    // Hand the output tokens back to the compiler
60    Ok(expanded)
61}
62
63fn parse_table_name(input: &DeriveInput) -> syn::Result<String> {
64    let table_name = input
65        .attrs
66        .iter()
67        .find(|attr| attr.path().is_ident("Table"))
68        .ok_or_else(|| {
69            syn::Error::new_spanned(input, "Table attribute must be specified for the struct")
70        })
71        .and_then(|attr| attr.parse_args::<syn::LitStr>().map(|lit| lit.value()))?;
72    Ok(table_name)
73}
74
75/// Get's the fields of the struct
76fn parse_struct_fields(input: &DeriveInput) -> syn::Result<impl Iterator<Item = &syn::Field>> {
77    match input.data {
78        Data::Struct(ref data) => match data.fields {
79            syn::Fields::Named(ref fields) => {
80                let mut fields = fields.named.iter().peekable();
81                if fields.peek().is_none() {
82                    return Err(syn::Error::new_spanned(
83                        input,
84                        "Struct must have at least one field",
85                    ));
86                }
87                Ok(fields)
88            }
89            _ => Err(syn::Error::new_spanned(
90                input,
91                "Tuple structs not supported",
92            )),
93        },
94        _ => Err(syn::Error::new_spanned(input, "Only structs are supported")),
95    }
96}
97
98/// Parses the `#[field]` and `#[index]` attributes on the fields of the struct
99fn parse_attributes<'a>(
100    fields: impl Iterator<Item = &'a syn::Field>,
101    table_name: &str,
102) -> syn::Result<(Vec<String>, Vec<String>)> {
103    let mut table_field_queries = Vec::new();
104
105    let mut index_queries = Vec::new();
106
107    for field in fields {
108        let Some(field_name) = field.ident.as_ref() else {
109            return Err(syn::Error::new_spanned(
110                field,
111                "Field must have a name, tuple structs not allowed",
112            ));
113        };
114        let mut field_attrs = field
115            .attrs
116            .iter()
117            .filter(|attr| attr.path().is_ident("field"))
118            .map(|attr| {
119                let parsed = attr.parse_args::<FieldAnnotation>();
120                match parsed {
121                    Ok(parsed) => Ok((attr, parsed)),
122                    Err(err) => Err(err),
123                }
124            })
125            .peekable();
126
127        // process the field attribute
128
129        // what (if anything) should be appended to the plain field definition query
130        let extra = match field_attrs.next() {
131            Some(Ok((_, FieldAnnotation::Skip))) => {
132                continue;
133            }
134            Some(Ok((_, FieldAnnotation::Plain))) => String::new(),
135            Some(Ok((_, FieldAnnotation::Typed { type_ }))) => format!(" TYPE {}", type_.value()),
136            Some(Ok((_, FieldAnnotation::CustomQuery { query }))) => {
137                format!(" {}", query.value())
138            }
139            Some(Err(err)) => {
140                return Err(err);
141            }
142            None => {
143                return Err(syn::Error::new_spanned(
144                    field,
145                    "Field must have a #[field] attribute",
146                ));
147            }
148        };
149        // next, make sure there was only one field attribute
150        if field_attrs.peek().is_some() {
151            return Err(syn::Error::new_spanned(
152                field,
153                "Field can have only one #[field] attribute",
154            ));
155        }
156
157        table_field_queries.push(format!("DEFINE FIELD {field_name} ON {table_name}{extra};",));
158
159        // next, we process the index attribute(s)
160        let index_attrs = field
161            .attrs
162            .iter()
163            .filter(|attr| attr.path().is_ident("index"))
164            .map(|attr| {
165                let parsed = attr.parse_args::<IndexAnnotation>();
166                match parsed {
167                    Ok(parsed) => Ok(parsed),
168                    Err(err) => Err(err),
169                }
170            })
171            .collect::<Result<Vec<_>, _>>()?;
172
173        for index in index_attrs {
174            for query in index.to_query_strings(table_name, &field_name.to_string()) {
175                index_queries.push(query);
176            }
177        }
178    }
179
180    Ok((table_field_queries, index_queries))
181}
182
183enum FieldAnnotation {
184    Skip,
185    Plain,
186    Typed { type_: syn::LitStr },
187    CustomQuery { query: syn::LitStr },
188}
189
190/// parses the `#[field]` attribute
191///
192/// the `#[field]` attribute can have the following keys:
193/// - `skip`: if set, the field will be skipped
194/// - `type`: the type of the field
195impl Parse for FieldAnnotation {
196    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
197        let args: Punctuated<syn::Expr, syn::token::Comma> =
198            input.parse_terminated(syn::Expr::parse, syn::token::Comma)?;
199
200        if args.is_empty() {
201            return Ok(Self::Plain);
202        }
203
204        if args.len() > 1 {
205            return Err(syn::Error::new_spanned(
206                args,
207                "Field attribute can have at most one argument",
208            ));
209        }
210
211        match args.first() {
212            None => Ok(Self::Plain),
213            Some(syn::Expr::Path(path)) if path.to_token_stream().to_string().eq("skip") => {
214                Ok(Self::Skip)
215            }
216            Some(syn::Expr::Lit(ExprLit {
217                lit: syn::Lit::Str(strlit),
218                ..
219            })) => Ok(Self::CustomQuery {
220                query: strlit.clone(),
221            }),
222            Some(syn::Expr::Assign(ExprAssign { left, right, .. })) => {
223                if left.to_token_stream().to_string().eq("dt") {
224                    match *right.to_owned() {
225                        syn::Expr::Lit(ExprLit {
226                            lit: syn::Lit::Str(strlit),
227                            ..
228                        }) => Ok(Self::Typed { type_: strlit }),
229                        _ => Err(syn::Error::new_spanned(
230                            right,
231                            "The `dt` attribute expects a string literal",
232                        )),
233                    }
234                } else {
235                    Err(syn::Error::new_spanned(
236                        left,
237                        "Unknown field attribute, expected `dt`",
238                    ))
239                }
240            }
241            Some(expr) => Err(syn::Error::new_spanned(
242                expr,
243                "Unsupported expression syntax, expected `skip`, `dt = \"type\"`, or a string literal representing a custom query",
244            )),
245        }
246    }
247}
248
249#[derive(Default, Debug, Clone)]
250struct IndexAnnotation {
251    indexes: Vec<IndexAnnotationInner>,
252}
253
254#[derive(Debug, Clone)]
255enum IndexAnnotationInner {
256    Compound(CompoundIndexAnnotation),
257    Single(IndexKind),
258}
259
260impl Parse for IndexAnnotation {
261    /// Parses the `#[index]` attribute
262    ///
263    /// The syntax for compound attributes is:
264    /// ```ignore
265    /// #[index(compound(unique, "field1", "field2")]
266    /// ```
267    ///
268    /// where `unique` is any of the valid index types, and `"field1"` and `"field2"` are the names of the fields to be indexed.
269    /// - you can have more than 2 fields in a compound index, really any amount > 1.
270    /// - the last arguments to `compound` must be strings representing field names the compound index is created on.
271    /// - the first argument, if not a string must be a valid index type, or nothing to default to a normal index.
272    ///
273    /// Here are some representative examples of valid index types:
274    /// ```ignore
275    /// #[index(unique)]
276    /// #[index()]
277    /// #[index(vector(dim = 128))]
278    /// #[index(text("english"))]
279    /// #[index(compound(unique, "field1", "field2"))]
280    /// #[index(compound("field1", "field2"))]
281    /// ```
282    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
283        // TODO: error if more than one of the same type of index is specified on the same field
284        let args: Punctuated<syn::Expr, syn::token::Comma> =
285            input.parse_terminated(syn::Expr::parse, syn::token::Comma)?;
286
287        if args.is_empty() {
288            return Ok(Self {
289                indexes: vec![IndexAnnotationInner::Single(IndexKind::Normal)],
290            });
291        }
292
293        let mut indexes = Vec::new();
294        for arg in &args {
295            match arg {
296                syn::Expr::Call(call) if call.func.to_token_stream().to_string().eq("compound") => {
297                    // parse the compound index from the
298                    indexes.push(IndexAnnotationInner::Compound(
299                        CompoundIndexAnnotation::parse(&call.args)?,
300                    ));
301                }
302                _ => {
303                    // parse the index type from the arg
304                    let index_type = IndexKind::parse(Some(arg))?;
305                    indexes.push(IndexAnnotationInner::Single(index_type));
306                }
307            }
308        }
309
310        Ok(Self { indexes })
311    }
312}
313
314impl IndexAnnotation {
315    // if both vector and full-text indexes are set, return None
316    fn to_query_strings(&self, table_name: &str, field_name: &str) -> Vec<String> {
317        let mut output = Vec::new();
318        for index in &self.indexes {
319            let (compound, index_type) = match index {
320                IndexAnnotationInner::Compound(compound_index_annotation) => (
321                    Some(&compound_index_annotation.fields),
322                    &compound_index_annotation.index,
323                ),
324                IndexAnnotationInner::Single(index_kind) => (None, index_kind),
325            };
326
327            let (extra, index_type) = match index_type {
328                IndexKind::Vector(vector) => (format!(" MTREE DIMENSION {}", vector.dim), "vector"),
329                IndexKind::Text(text) => {
330                    (format!(" SEARCH ANALYZER {} BM25", text.analyzer), "text")
331                }
332                IndexKind::Normal => (String::new(), "normal"),
333                IndexKind::Unique => (String::from(" UNIQUE"), "unique"),
334            };
335            let compound_fields = |sep: &str| match compound {
336                Some(compound) if !compound.is_empty() => {
337                    format!("{sep}{}", compound.join(sep))
338                }
339                _ => String::new(),
340            };
341
342            let index_name = format!(
343                "{table_name}_{field_name}{extra_fields}_{index_type}_index",
344                extra_fields = compound_fields("_")
345            );
346
347            let query = format!(
348                "DEFINE INDEX {index_name} ON {table_name} FIELDS {field_name}{extra_fields}{extra};",
349                extra_fields = compound_fields(",")
350            );
351
352            output.push(query);
353        }
354
355        output
356    }
357}
358
359#[derive(Default, Debug, Clone)]
360/// A compound index is an index that is created across multiple fields.
361struct CompoundIndexAnnotation {
362    index: IndexKind,
363    fields: Vec<String>,
364}
365
366impl CompoundIndexAnnotation {
367    fn parse(args: &Punctuated<syn::Expr, syn::token::Comma>) -> syn::Result<Self> {
368        let mut fields = Vec::new();
369
370        let mut args_iter = args.iter();
371
372        // the first argument (if not a string)
373        let index = match args_iter.next() {
374            Some(syn::Expr::Lit(ExprLit {
375                lit: syn::Lit::Str(strlit),
376                ..
377            })) => {
378                fields.push(strlit.value());
379                IndexKind::Normal
380            }
381            arg => match IndexKind::parse(arg) {
382                Ok(index_type) => index_type,
383                Err(mut err) => {
384                    err.combine(syn::Error::new_spanned(
385                            arg,
386                            "Compound index attribute expects a valid index type or string literal representing the first field name as the first argument",
387                        ));
388                    return Err(err);
389                }
390            },
391        };
392
393        // the remaining arguments should be string literals representing field names
394        for arg in args_iter {
395            match arg {
396                syn::Expr::Lit(ExprLit {
397                    lit: syn::Lit::Str(strlit),
398                    ..
399                }) => fields.push(strlit.value()),
400                _ => {
401                    return Err(syn::Error::new_spanned(
402                        arg,
403                        "Compound index attribute expects string literals representing the other field names",
404                    ));
405                }
406            }
407        }
408
409        if fields.is_empty() {
410            Err(syn::Error::new_spanned(
411                args,
412                "Compound index attribute expects at least one string literal representing the other field names",
413            ))
414        } else {
415            Ok(Self { index, fields })
416        }
417    }
418}
419
420#[derive(Default, Debug, Clone)]
421enum IndexKind {
422    Vector(VectorIndexAnnotation),
423    Text(TextIndexAnnotation),
424    #[default]
425    Normal,
426    Unique,
427}
428
429impl IndexKind {
430    fn parse(arg: Option<&syn::Expr>) -> syn::Result<Self> {
431        match arg {
432            None => Ok(Self::Normal),
433            Some(syn::Expr::Path(path)) if path.to_token_stream().to_string().eq("unique") => {
434                Ok(Self::Unique)
435            }
436            Some(syn::Expr::Call(call)) if call.func.to_token_stream().to_string().eq("vector") => {
437                Ok(Self::Vector(VectorIndexAnnotation::parse(&call.args)?))
438            }
439            Some(syn::Expr::Call(call)) if call.func.to_token_stream().to_string().eq("text") => {
440                Ok(Self::Text(TextIndexAnnotation::parse(&call.args)?))
441            }
442            _ => Err(syn::Error::new_spanned(
443                arg,
444                "Unsupported expression syntax",
445            )),
446        }
447    }
448}
449
450#[derive(Debug, Copy, Clone)]
451struct VectorIndexAnnotation {
452    dim: usize,
453}
454
455impl VectorIndexAnnotation {
456    fn parse(args: &Punctuated<syn::Expr, syn::token::Comma>) -> syn::Result<Self> {
457        let mut args_iter = args.iter();
458        let arg = args_iter.next();
459        if args_iter.next().is_some() {
460            return Err(syn::Error::new_spanned(
461                args,
462                "Vector index attribute only expects one argument, the dimension of the vector",
463            ));
464        }
465
466        let dim = match arg {
467            Some(syn::Expr::Assign(ExprAssign { left, right, .. }))
468                if left.to_token_stream().to_string().eq("dim") =>
469            {
470                match *right.to_owned() {
471                    syn::Expr::Lit(ExprLit {
472                        lit: syn::Lit::Int(int),
473                        ..
474                    }) => int.base10_parse()?,
475                    _ => {
476                        return Err(syn::Error::new_spanned(
477                            right,
478                            "`dim` expects an integer literal representing the number of dimensions in the vector",
479                        ));
480                    }
481                }
482            }
483            Some(syn::Expr::Lit(ExprLit {
484                lit: syn::Lit::Int(int),
485                ..
486            })) => int.base10_parse()?,
487            _ => {
488                return Err(syn::Error::new_spanned(
489                    arg,
490                    "Unsupported expression syntax",
491                ));
492            }
493        };
494
495        if dim < 1 {
496            return Err(syn::Error::new_spanned(
497                arg,
498                "Vector dimension must be greater than 0",
499            ));
500        }
501
502        Ok(Self { dim })
503    }
504}
505
506#[derive(Debug, Clone)]
507struct TextIndexAnnotation {
508    analyzer: String,
509}
510
511impl TextIndexAnnotation {
512    fn parse(args: &Punctuated<syn::Expr, syn::token::Comma>) -> syn::Result<Self> {
513        // should only read one argument, the analyzer (string literal)
514        let mut args_iter = args.iter();
515        let arg = args_iter.next();
516
517        if args_iter.next().is_some() {
518            return Err(syn::Error::new_spanned(
519                args,
520                "Text index attribute only expects one argument, the analyzer to use",
521            ));
522        }
523
524        let analyzer = match arg {
525            Some(syn::Expr::Lit(ExprLit {
526                lit: syn::Lit::Str(strlit),
527                ..
528            })) => strlit.value(),
529            _ => {
530                return Err(syn::Error::new_spanned(
531                    arg,
532                    "Text index attribute expects a string literal representing the analyzer to use",
533                ));
534            }
535        };
536
537        Ok(Self { analyzer })
538    }
539}