Skip to main content

surrealqlx_macros_impl/
lib.rs

1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote};
3use syn::{Data, DeriveInput, ExprAssign, ExprLit, parse::Parse, punctuated::Punctuated};
4
5#[cfg(test)]
6mod tests;
7
8mod surrql;
9pub use surrql::surrql_macro_impl;
10
11/// Implementation of the Table derive macro
12///
13/// # Errors
14///
15/// This function will return an error if the input couldn't be parsed, or if attributes are missing or invalid.
16pub fn table_macro_impl(input: TokenStream) -> syn::Result<TokenStream> {
17    let input = syn::parse2::<DeriveInput>(input)?;
18
19    let struct_name = &input.ident;
20
21    let table_name = parse_table_name(&input)?;
22
23    let struct_fields = parse_struct_fields(&input)?;
24
25    let (table_field_queries, index_queries) = parse_attributes(struct_fields, &table_name)?;
26
27    let table_query = format!("DEFINE TABLE IF NOT EXISTS {table_name} SCHEMAFULL;");
28
29    let migration_up_query = format!(
30        "{table_query}\n{}",
31        table_field_queries
32            .into_iter()
33            .chain(index_queries)
34            .collect::<Vec<_>>()
35            .join("\n")
36    );
37
38    // Build the output, possibly using the input
39    let expanded = quote! {
40        // The generated impl goes here
41        impl ::surrealqlx::traits::Table for #struct_name {
42            const TABLE_NAME: &'static str = #table_name;
43
44            fn migrations() -> Vec<::surrealqlx::migrations::M<'static>> {
45                vec![
46                    ::surrealqlx::migrations::M::up(
47                        #migration_up_query
48                    )
49                    .comment("Initial version"),
50                ]
51            }
52        }
53    };
54
55    // Hand the output tokens back to the compiler
56    Ok(expanded)
57}
58
59fn parse_table_name(input: &DeriveInput) -> syn::Result<String> {
60    let table_name = input
61        .attrs
62        .iter()
63        .find(|attr| attr.path().is_ident("Table"))
64        .ok_or_else(|| {
65            syn::Error::new_spanned(input, "Table attribute must be specified for the struct")
66        })
67        .and_then(|attr| attr.parse_args::<syn::LitStr>().map(|lit| lit.value()))?;
68    Ok(table_name)
69}
70
71/// Get's the fields of the struct
72fn parse_struct_fields(input: &DeriveInput) -> syn::Result<impl Iterator<Item = &syn::Field>> {
73    match input.data {
74        Data::Struct(ref data) => match data.fields {
75            syn::Fields::Named(ref fields) => {
76                let mut fields = fields.named.iter().peekable();
77                if fields.peek().is_none() {
78                    return Err(syn::Error::new_spanned(
79                        input,
80                        "Struct must have at least one field",
81                    ));
82                }
83                Ok(fields)
84            }
85            _ => Err(syn::Error::new_spanned(
86                input,
87                "Tuple structs not supported",
88            )),
89        },
90        _ => Err(syn::Error::new_spanned(input, "Only structs are supported")),
91    }
92}
93
94/// Parses the `#[field]` and `#[index]` attributes on the fields of the struct
95fn parse_attributes<'a>(
96    fields: impl Iterator<Item = &'a syn::Field>,
97    table_name: &str,
98) -> syn::Result<(Vec<String>, Vec<String>)> {
99    let mut table_field_queries = Vec::new();
100
101    let mut index_queries = Vec::new();
102
103    for field in fields {
104        let Some(field_name) = field.ident.as_ref() else {
105            return Err(syn::Error::new_spanned(
106                field,
107                "Field must have a name, tuple structs not allowed",
108            ));
109        };
110        let mut field_attrs = field
111            .attrs
112            .iter()
113            .filter(|attr| attr.path().is_ident("field"))
114            .map(|attr| {
115                let parsed = attr.parse_args::<FieldAnnotation>();
116                match parsed {
117                    Ok(parsed) => Ok((attr, parsed)),
118                    Err(err) => Err(err),
119                }
120            })
121            .peekable();
122
123        // process the field attribute
124
125        // what (if anything) should be appended to the plain field definition query
126        let extra = match field_attrs.next() {
127            Some(Ok((_, FieldAnnotation::Skip))) => {
128                continue;
129            }
130            Some(Ok((_, FieldAnnotation::Plain))) => String::new(),
131            Some(Ok((_, FieldAnnotation::Typed { type_ }))) => format!(" TYPE {}", type_.value()),
132            Some(Ok((_, FieldAnnotation::CustomQuery { query }))) => {
133                format!(" {}", query.value())
134            }
135            Some(Err(err)) => {
136                return Err(err);
137            }
138            None => {
139                return Err(syn::Error::new_spanned(
140                    field,
141                    "Field must have a #[field] attribute",
142                ));
143            }
144        };
145        // next, make sure there was only one field attribute
146        if field_attrs.peek().is_some() {
147            return Err(syn::Error::new_spanned(
148                field,
149                "Field can have only one #[field] attribute",
150            ));
151        }
152
153        table_field_queries.push(format!(
154            "DEFINE FIELD IF NOT EXISTS {field_name} ON {table_name}{extra};",
155        ));
156
157        // next, we process the index attribute(s)
158        let index_attrs = field
159            .attrs
160            .iter()
161            .filter(|attr| attr.path().is_ident("index"))
162            .map(|attr| {
163                let parsed = attr.parse_args::<IndexAnnotation>();
164                match parsed {
165                    Ok(parsed) => Ok(parsed),
166                    Err(err) => Err(err),
167                }
168            })
169            .collect::<Result<Vec<_>, _>>()?;
170
171        for index in index_attrs {
172            for query in index.to_query_strings(table_name, &field_name.to_string()) {
173                index_queries.push(query);
174            }
175        }
176    }
177
178    Ok((table_field_queries, index_queries))
179}
180
181enum FieldAnnotation {
182    Skip,
183    Plain,
184    Typed { type_: syn::LitStr },
185    CustomQuery { query: syn::LitStr },
186}
187
188/// parses the `#[field]` attribute
189///
190/// the `#[field]` attribute can have the following keys:
191/// - `skip`: if set, the field will be skipped
192/// - `type`: the type of the field
193impl Parse for FieldAnnotation {
194    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
195        let args: Punctuated<syn::Expr, syn::token::Comma> =
196            input.parse_terminated(syn::Expr::parse, syn::token::Comma)?;
197
198        if args.is_empty() {
199            return Ok(Self::Plain);
200        }
201
202        if args.len() > 1 {
203            return Err(syn::Error::new_spanned(
204                args,
205                "Field attribute can have at most one argument",
206            ));
207        }
208
209        match args.first() {
210            None => Ok(Self::Plain),
211            Some(syn::Expr::Path(path)) if path.to_token_stream().to_string().eq("skip") => {
212                Ok(Self::Skip)
213            }
214            Some(syn::Expr::Lit(ExprLit {
215                lit: syn::Lit::Str(strlit),
216                ..
217            })) => Ok(Self::CustomQuery {
218                query: strlit.clone(),
219            }),
220            Some(syn::Expr::Assign(ExprAssign { left, right, .. })) => {
221                if left.to_token_stream().to_string().eq("dt") {
222                    match *right.to_owned() {
223                        syn::Expr::Lit(ExprLit {
224                            lit: syn::Lit::Str(strlit),
225                            ..
226                        }) => Ok(Self::Typed { type_: strlit }),
227                        _ => Err(syn::Error::new_spanned(
228                            right,
229                            "The `dt` attribute expects a string literal",
230                        )),
231                    }
232                } else {
233                    Err(syn::Error::new_spanned(
234                        left,
235                        "Unknown field attribute, expected `dt`",
236                    ))
237                }
238            }
239            Some(expr) => Err(syn::Error::new_spanned(
240                expr,
241                "Unsupported expression syntax, expected `skip`, `dt = \"type\"`, or a string literal representing a custom query",
242            )),
243        }
244    }
245}
246
247#[derive(Default, Debug, Clone)]
248struct IndexAnnotation {
249    indexes: Vec<IndexAnnotationInner>,
250}
251
252#[derive(Debug, Clone)]
253enum IndexAnnotationInner {
254    Compound(CompoundIndexAnnotation),
255    Single(IndexKind),
256}
257
258impl Parse for IndexAnnotation {
259    /// Parses the `#[index]` attribute
260    ///
261    /// The syntax for compound attributes is:
262    /// ```ignore
263    /// #[index(compound(unique, "field1", "field2")]
264    /// ```
265    ///
266    /// where `unique` is any of the valid index types, and `"field1"` and `"field2"` are the names of the fields to be indexed.
267    /// - you can have more than 2 fields in a compound index, really any amount > 1.
268    /// - the last arguments to `compound` must be strings representing field names the compound index is created on.
269    /// - the first argument, if not a string must be a valid index type, or nothing to default to a normal index.
270    ///
271    /// Here are some representative examples of valid index types:
272    /// ```ignore
273    /// #[index(unique)]
274    /// #[index()]
275    /// #[index(vector(dim = 128))]
276    /// #[index(text("english"))]
277    /// #[index(compound(unique, "field1", "field2"))]
278    /// #[index(compound("field1", "field2"))]
279    /// ```
280    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
281        // TODO: error if more than one of the same type of index is specified on the same field
282        let args: Punctuated<syn::Expr, syn::token::Comma> =
283            input.parse_terminated(syn::Expr::parse, syn::token::Comma)?;
284
285        if args.is_empty() {
286            return Ok(Self {
287                indexes: vec![IndexAnnotationInner::Single(IndexKind::Normal)],
288            });
289        }
290
291        let mut indexes = Vec::new();
292        for arg in &args {
293            match arg {
294                syn::Expr::Call(call) if call.func.to_token_stream().to_string().eq("compound") => {
295                    // parse the compound index from the
296                    indexes.push(IndexAnnotationInner::Compound(
297                        CompoundIndexAnnotation::parse(&call.args)?,
298                    ));
299                }
300                _ => {
301                    // parse the index type from the arg
302                    let index_type = IndexKind::parse(Some(arg))?;
303                    indexes.push(IndexAnnotationInner::Single(index_type));
304                }
305            }
306        }
307
308        Ok(Self { indexes })
309    }
310}
311
312impl IndexAnnotation {
313    // if both vector and full-text indexes are set, return None
314    fn to_query_strings(&self, table_name: &str, field_name: &str) -> Vec<String> {
315        let mut output = Vec::new();
316        for index in &self.indexes {
317            let (compound, index_type) = match index {
318                IndexAnnotationInner::Compound(compound_index_annotation) => (
319                    Some(&compound_index_annotation.fields),
320                    &compound_index_annotation.index,
321                ),
322                IndexAnnotationInner::Single(index_kind) => (None, index_kind),
323            };
324
325            let (extra, index_type) = match index_type {
326                IndexKind::Vector(vector) => (format!(" MTREE DIMENSION {}", vector.dim), "vector"),
327                IndexKind::Text(text) => {
328                    (format!(" SEARCH ANALYZER {} BM25", text.analyzer), "text")
329                }
330                IndexKind::Normal => (String::new(), "normal"),
331                IndexKind::Unique => (String::from(" UNIQUE"), "unique"),
332            };
333            let compound_fields = |sep: &str| match compound {
334                Some(compound) if !compound.is_empty() => {
335                    format!("{sep}{}", compound.join(sep))
336                }
337                _ => String::new(),
338            };
339
340            let index_name = format!(
341                "{table_name}_{field_name}{extra_fields}_{index_type}_index",
342                extra_fields = compound_fields("_")
343            );
344
345            let query = format!(
346                "DEFINE INDEX IF NOT EXISTS {index_name} ON {table_name} FIELDS {field_name}{extra_fields}{extra};",
347                extra_fields = compound_fields(",")
348            );
349
350            output.push(query);
351        }
352
353        output
354    }
355}
356
357#[derive(Default, Debug, Clone)]
358/// A compound index is an index that is created across multiple fields.
359struct CompoundIndexAnnotation {
360    index: IndexKind,
361    fields: Vec<String>,
362}
363
364impl CompoundIndexAnnotation {
365    fn parse(args: &Punctuated<syn::Expr, syn::token::Comma>) -> syn::Result<Self> {
366        let mut fields = Vec::new();
367
368        let mut args_iter = args.iter();
369
370        // the first argument (if not a string)
371        let index = match args_iter.next() {
372            Some(syn::Expr::Lit(ExprLit {
373                lit: syn::Lit::Str(strlit),
374                ..
375            })) => {
376                fields.push(strlit.value());
377                IndexKind::Normal
378            }
379            arg => match IndexKind::parse(arg) {
380                Ok(index_type) => index_type,
381                Err(mut err) => {
382                    err.combine(syn::Error::new_spanned(
383                            arg,
384                            "Compound index attribute expects a valid index type or string literal representing the first field name as the first argument",
385                        ));
386                    return Err(err);
387                }
388            },
389        };
390
391        // the remaining arguments should be string literals representing field names
392        for arg in args_iter {
393            match arg {
394                syn::Expr::Lit(ExprLit {
395                    lit: syn::Lit::Str(strlit),
396                    ..
397                }) => fields.push(strlit.value()),
398                _ => {
399                    return Err(syn::Error::new_spanned(
400                        arg,
401                        "Compound index attribute expects string literals representing the other field names",
402                    ));
403                }
404            }
405        }
406
407        if fields.is_empty() {
408            Err(syn::Error::new_spanned(
409                args,
410                "Compound index attribute expects at least one string literal representing the other field names",
411            ))
412        } else {
413            Ok(Self { index, fields })
414        }
415    }
416}
417
418#[derive(Default, Debug, Clone)]
419enum IndexKind {
420    Vector(VectorIndexAnnotation),
421    Text(TextIndexAnnotation),
422    #[default]
423    Normal,
424    Unique,
425}
426
427impl IndexKind {
428    fn parse(arg: Option<&syn::Expr>) -> syn::Result<Self> {
429        match arg {
430            None => Ok(Self::Normal),
431            Some(syn::Expr::Path(path)) if path.to_token_stream().to_string().eq("unique") => {
432                Ok(Self::Unique)
433            }
434            Some(syn::Expr::Call(call)) if call.func.to_token_stream().to_string().eq("vector") => {
435                Ok(Self::Vector(VectorIndexAnnotation::parse(&call.args)?))
436            }
437            Some(syn::Expr::Call(call)) if call.func.to_token_stream().to_string().eq("text") => {
438                Ok(Self::Text(TextIndexAnnotation::parse(&call.args)?))
439            }
440            _ => Err(syn::Error::new_spanned(
441                arg,
442                "Unsupported expression syntax",
443            )),
444        }
445    }
446}
447
448#[derive(Debug, Copy, Clone)]
449struct VectorIndexAnnotation {
450    dim: usize,
451}
452
453impl VectorIndexAnnotation {
454    fn parse(args: &Punctuated<syn::Expr, syn::token::Comma>) -> syn::Result<Self> {
455        let mut args_iter = args.iter();
456        let arg = args_iter.next();
457        if args_iter.next().is_some() {
458            return Err(syn::Error::new_spanned(
459                args,
460                "Vector index attribute only expects one argument, the dimension of the vector",
461            ));
462        }
463
464        let dim = match arg {
465            Some(syn::Expr::Assign(ExprAssign { left, right, .. }))
466                if left.to_token_stream().to_string().eq("dim") =>
467            {
468                match *right.to_owned() {
469                    syn::Expr::Lit(ExprLit {
470                        lit: syn::Lit::Int(int),
471                        ..
472                    }) => int.base10_parse()?,
473                    _ => {
474                        return Err(syn::Error::new_spanned(
475                            right,
476                            "`dim` expects an integer literal representing the number of dimensions in the vector",
477                        ));
478                    }
479                }
480            }
481            Some(syn::Expr::Lit(ExprLit {
482                lit: syn::Lit::Int(int),
483                ..
484            })) => int.base10_parse()?,
485            _ => {
486                return Err(syn::Error::new_spanned(
487                    arg,
488                    "Unsupported expression syntax",
489                ));
490            }
491        };
492
493        if dim < 1 {
494            return Err(syn::Error::new_spanned(
495                arg,
496                "Vector dimension must be greater than 0",
497            ));
498        }
499
500        Ok(Self { dim })
501    }
502}
503
504#[derive(Debug, Clone)]
505struct TextIndexAnnotation {
506    analyzer: String,
507}
508
509impl TextIndexAnnotation {
510    fn parse(args: &Punctuated<syn::Expr, syn::token::Comma>) -> syn::Result<Self> {
511        // should only read one argument, the analyzer (string literal)
512        let mut args_iter = args.iter();
513        let arg = args_iter.next();
514
515        if args_iter.next().is_some() {
516            return Err(syn::Error::new_spanned(
517                args,
518                "Text index attribute only expects one argument, the analyzer to use",
519            ));
520        }
521
522        let analyzer = match arg {
523            Some(syn::Expr::Lit(ExprLit {
524                lit: syn::Lit::Str(strlit),
525                ..
526            })) => strlit.value(),
527            _ => {
528                return Err(syn::Error::new_spanned(
529                    arg,
530                    "Text index attribute expects a string literal representing the analyzer to use",
531                ));
532            }
533        };
534
535        Ok(Self { analyzer })
536    }
537}