1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote};
3use syn::{Data, DeriveInput, ExprAssign, ExprLit, parse::Parse, punctuated::Punctuated};
4
5#[cfg(test)]
6mod tests;
7
8pub fn table_macro_impl(input: TokenStream) -> syn::Result<TokenStream> {
14 let input = syn::parse2::<DeriveInput>(input)?;
15
16 let struct_name = &input.ident;
17
18 let table_name = parse_table_name(&input)?;
19
20 let struct_fields = parse_struct_fields(&input)?;
21
22 let (table_field_queries, index_queries) = parse_attributes(struct_fields, &table_name)?;
23
24 let table_query = format!("DEFINE TABLE {table_name} SCHEMAFULL;");
25
26 let table_field_queries = table_field_queries.iter().map(|q| quote! {.query(#q)});
27 let index_queries = index_queries.iter().map(|q| quote! {.query(#q)});
28
29 let expanded = quote! {
31 impl ::surrealqlx::traits::Table for #struct_name {
33 const TABLE_NAME: &'static str = #table_name;
34
35 #[allow(manual_async_fn)]
36 fn init_table<C: ::surrealdb::Connection>(
37 db: &::surrealdb::Surreal<C>,
38 ) -> impl ::std::future::Future<Output = ::surrealdb::Result<()>> + Send {
39 async {
40 let _ = db.query("BEGIN;")
41 .query(#table_query)
42 .query("COMMIT;")
43 .query("BEGIN;")
44 #(
45 #table_field_queries
46 )*
47 .query("COMMIT;")
48 .query("BEGIN;")
49 #(
50 #index_queries
51 )*
52 .query("COMMIT;").await?;
53 Ok(())
54 }
55 }
56 }
57 };
58
59 Ok(expanded)
61}
62
63fn parse_table_name(input: &DeriveInput) -> syn::Result<String> {
64 let table_name = input
65 .attrs
66 .iter()
67 .find(|attr| attr.path().is_ident("Table"))
68 .ok_or_else(|| {
69 syn::Error::new_spanned(input, "Table attribute must be specified for the struct")
70 })
71 .and_then(|attr| attr.parse_args::<syn::LitStr>().map(|lit| lit.value()))?;
72 Ok(table_name)
73}
74
75fn parse_struct_fields(input: &DeriveInput) -> syn::Result<impl Iterator<Item = &syn::Field>> {
77 match input.data {
78 Data::Struct(ref data) => match data.fields {
79 syn::Fields::Named(ref fields) => {
80 let mut fields = fields.named.iter().peekable();
81 if fields.peek().is_none() {
82 return Err(syn::Error::new_spanned(
83 input,
84 "Struct must have at least one field",
85 ));
86 }
87 Ok(fields)
88 }
89 _ => Err(syn::Error::new_spanned(
90 input,
91 "Tuple structs not supported",
92 )),
93 },
94 _ => Err(syn::Error::new_spanned(input, "Only structs are supported")),
95 }
96}
97
98fn parse_attributes<'a>(
100 fields: impl Iterator<Item = &'a syn::Field>,
101 table_name: &str,
102) -> syn::Result<(Vec<String>, Vec<String>)> {
103 let mut table_field_queries = Vec::new();
104
105 let mut index_queries = Vec::new();
106
107 for field in fields {
108 let Some(field_name) = field.ident.as_ref() else {
109 return Err(syn::Error::new_spanned(
110 field,
111 "Field must have a name, tuple structs not allowed",
112 ));
113 };
114 let mut field_attrs = field
115 .attrs
116 .iter()
117 .filter(|attr| attr.path().is_ident("field"))
118 .map(|attr| {
119 let parsed = attr.parse_args::<FieldAnnotation>();
120 match parsed {
121 Ok(parsed) => Ok((attr, parsed)),
122 Err(err) => Err(err),
123 }
124 })
125 .peekable();
126
127 let extra = match field_attrs.next() {
131 Some(Ok((_, FieldAnnotation::Skip))) => {
132 continue;
133 }
134 Some(Ok((_, FieldAnnotation::Plain))) => String::new(),
135 Some(Ok((_, FieldAnnotation::Typed { type_ }))) => format!(" TYPE {}", type_.value()),
136 Some(Ok((_, FieldAnnotation::CustomQuery { query }))) => {
137 format!(" {}", query.value())
138 }
139 Some(Err(err)) => {
140 return Err(err);
141 }
142 None => {
143 return Err(syn::Error::new_spanned(
144 field,
145 "Field must have a #[field] attribute",
146 ));
147 }
148 };
149 if field_attrs.peek().is_some() {
151 return Err(syn::Error::new_spanned(
152 field,
153 "Field can have only one #[field] attribute",
154 ));
155 }
156
157 table_field_queries.push(format!("DEFINE FIELD {field_name} ON {table_name}{extra};",));
158
159 let index_attrs = field
161 .attrs
162 .iter()
163 .filter(|attr| attr.path().is_ident("index"))
164 .map(|attr| {
165 let parsed = attr.parse_args::<IndexAnnotation>();
166 match parsed {
167 Ok(parsed) => Ok(parsed),
168 Err(err) => Err(err),
169 }
170 })
171 .collect::<Result<Vec<_>, _>>()?;
172
173 for index in index_attrs {
174 for query in index.to_query_strings(table_name, &field_name.to_string()) {
175 index_queries.push(query);
176 }
177 }
178 }
179
180 Ok((table_field_queries, index_queries))
181}
182
183enum FieldAnnotation {
184 Skip,
185 Plain,
186 Typed { type_: syn::LitStr },
187 CustomQuery { query: syn::LitStr },
188}
189
190impl Parse for FieldAnnotation {
196 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
197 let args: Punctuated<syn::Expr, syn::token::Comma> =
198 input.parse_terminated(syn::Expr::parse, syn::token::Comma)?;
199
200 if args.is_empty() {
201 return Ok(Self::Plain);
202 }
203
204 if args.len() > 1 {
205 return Err(syn::Error::new_spanned(
206 args,
207 "Field attribute can have at most one argument",
208 ));
209 }
210
211 match args.first() {
212 None => Ok(Self::Plain),
213 Some(syn::Expr::Path(path)) if path.to_token_stream().to_string().eq("skip") => {
214 Ok(Self::Skip)
215 }
216 Some(syn::Expr::Lit(ExprLit {
217 lit: syn::Lit::Str(strlit),
218 ..
219 })) => Ok(Self::CustomQuery {
220 query: strlit.clone(),
221 }),
222 Some(syn::Expr::Assign(ExprAssign { left, right, .. })) => {
223 if left.to_token_stream().to_string().eq("dt") {
224 match *right.to_owned() {
225 syn::Expr::Lit(ExprLit {
226 lit: syn::Lit::Str(strlit),
227 ..
228 }) => Ok(Self::Typed { type_: strlit }),
229 _ => Err(syn::Error::new_spanned(
230 right,
231 "The `dt` attribute expects a string literal",
232 )),
233 }
234 } else {
235 Err(syn::Error::new_spanned(
236 left,
237 "Unknown field attribute, expected `dt`",
238 ))
239 }
240 }
241 Some(expr) => Err(syn::Error::new_spanned(
242 expr,
243 "Unsupported expression syntax, expected `skip`, `dt = \"type\"`, or a string literal representing a custom query",
244 )),
245 }
246 }
247}
248
249#[derive(Default, Debug, Clone)]
250struct IndexAnnotation {
251 indexes: Vec<IndexAnnotationInner>,
252}
253
254#[derive(Debug, Clone)]
255enum IndexAnnotationInner {
256 Compound(CompoundIndexAnnotation),
257 Single(IndexKind),
258}
259
260impl Parse for IndexAnnotation {
261 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
283 let args: Punctuated<syn::Expr, syn::token::Comma> =
285 input.parse_terminated(syn::Expr::parse, syn::token::Comma)?;
286
287 if args.is_empty() {
288 return Ok(Self {
289 indexes: vec![IndexAnnotationInner::Single(IndexKind::Normal)],
290 });
291 }
292
293 let mut indexes = Vec::new();
294 for arg in &args {
295 match arg {
296 syn::Expr::Call(call) if call.func.to_token_stream().to_string().eq("compound") => {
297 indexes.push(IndexAnnotationInner::Compound(
299 CompoundIndexAnnotation::parse(&call.args)?,
300 ));
301 }
302 _ => {
303 let index_type = IndexKind::parse(Some(arg))?;
305 indexes.push(IndexAnnotationInner::Single(index_type));
306 }
307 }
308 }
309
310 Ok(Self { indexes })
311 }
312}
313
314impl IndexAnnotation {
315 fn to_query_strings(&self, table_name: &str, field_name: &str) -> Vec<String> {
317 let mut output = Vec::new();
318 for index in &self.indexes {
319 let (compound, index_type) = match index {
320 IndexAnnotationInner::Compound(compound_index_annotation) => (
321 Some(&compound_index_annotation.fields),
322 &compound_index_annotation.index,
323 ),
324 IndexAnnotationInner::Single(index_kind) => (None, index_kind),
325 };
326
327 let (extra, index_type) = match index_type {
328 IndexKind::Vector(vector) => (format!(" MTREE DIMENSION {}", vector.dim), "vector"),
329 IndexKind::Text(text) => {
330 (format!(" SEARCH ANALYZER {} BM25", text.analyzer), "text")
331 }
332 IndexKind::Normal => (String::new(), "normal"),
333 IndexKind::Unique => (String::from(" UNIQUE"), "unique"),
334 };
335 let compound_fields = |sep: &str| match compound {
336 Some(compound) if !compound.is_empty() => {
337 format!("{sep}{}", compound.join(sep))
338 }
339 _ => String::new(),
340 };
341
342 let index_name = format!(
343 "{table_name}_{field_name}{extra_fields}_{index_type}_index",
344 extra_fields = compound_fields("_")
345 );
346
347 let query = format!(
348 "DEFINE INDEX {index_name} ON {table_name} FIELDS {field_name}{extra_fields}{extra};",
349 extra_fields = compound_fields(",")
350 );
351
352 output.push(query);
353 }
354
355 output
356 }
357}
358
359#[derive(Default, Debug, Clone)]
360struct CompoundIndexAnnotation {
362 index: IndexKind,
363 fields: Vec<String>,
364}
365
366impl CompoundIndexAnnotation {
367 fn parse(args: &Punctuated<syn::Expr, syn::token::Comma>) -> syn::Result<Self> {
368 let mut fields = Vec::new();
369
370 let mut args_iter = args.iter();
371
372 let index = match args_iter.next() {
374 Some(syn::Expr::Lit(ExprLit {
375 lit: syn::Lit::Str(strlit),
376 ..
377 })) => {
378 fields.push(strlit.value());
379 IndexKind::Normal
380 }
381 arg => match IndexKind::parse(arg) {
382 Ok(index_type) => index_type,
383 Err(mut err) => {
384 err.combine(syn::Error::new_spanned(
385 arg,
386 "Compound index attribute expects a valid index type or string literal representing the first field name as the first argument",
387 ));
388 return Err(err);
389 }
390 },
391 };
392
393 for arg in args_iter {
395 match arg {
396 syn::Expr::Lit(ExprLit {
397 lit: syn::Lit::Str(strlit),
398 ..
399 }) => fields.push(strlit.value()),
400 _ => {
401 return Err(syn::Error::new_spanned(
402 arg,
403 "Compound index attribute expects string literals representing the other field names",
404 ));
405 }
406 }
407 }
408
409 if fields.is_empty() {
410 Err(syn::Error::new_spanned(
411 args,
412 "Compound index attribute expects at least one string literal representing the other field names",
413 ))
414 } else {
415 Ok(Self { index, fields })
416 }
417 }
418}
419
420#[derive(Default, Debug, Clone)]
421enum IndexKind {
422 Vector(VectorIndexAnnotation),
423 Text(TextIndexAnnotation),
424 #[default]
425 Normal,
426 Unique,
427}
428
429impl IndexKind {
430 fn parse(arg: Option<&syn::Expr>) -> syn::Result<Self> {
431 match arg {
432 None => Ok(Self::Normal),
433 Some(syn::Expr::Path(path)) if path.to_token_stream().to_string().eq("unique") => {
434 Ok(Self::Unique)
435 }
436 Some(syn::Expr::Call(call)) if call.func.to_token_stream().to_string().eq("vector") => {
437 Ok(Self::Vector(VectorIndexAnnotation::parse(&call.args)?))
438 }
439 Some(syn::Expr::Call(call)) if call.func.to_token_stream().to_string().eq("text") => {
440 Ok(Self::Text(TextIndexAnnotation::parse(&call.args)?))
441 }
442 _ => Err(syn::Error::new_spanned(
443 arg,
444 "Unsupported expression syntax",
445 )),
446 }
447 }
448}
449
450#[derive(Debug, Copy, Clone)]
451struct VectorIndexAnnotation {
452 dim: usize,
453}
454
455impl VectorIndexAnnotation {
456 fn parse(args: &Punctuated<syn::Expr, syn::token::Comma>) -> syn::Result<Self> {
457 let mut args_iter = args.iter();
458 let arg = args_iter.next();
459 if args_iter.next().is_some() {
460 return Err(syn::Error::new_spanned(
461 args,
462 "Vector index attribute only expects one argument, the dimension of the vector",
463 ));
464 }
465
466 let dim = match arg {
467 Some(syn::Expr::Assign(ExprAssign { left, right, .. }))
468 if left.to_token_stream().to_string().eq("dim") =>
469 {
470 match *right.to_owned() {
471 syn::Expr::Lit(ExprLit {
472 lit: syn::Lit::Int(int),
473 ..
474 }) => int.base10_parse()?,
475 _ => {
476 return Err(syn::Error::new_spanned(
477 right,
478 "`dim` expects an integer literal representing the number of dimensions in the vector",
479 ));
480 }
481 }
482 }
483 Some(syn::Expr::Lit(ExprLit {
484 lit: syn::Lit::Int(int),
485 ..
486 })) => int.base10_parse()?,
487 _ => {
488 return Err(syn::Error::new_spanned(
489 arg,
490 "Unsupported expression syntax",
491 ));
492 }
493 };
494
495 if dim < 1 {
496 return Err(syn::Error::new_spanned(
497 arg,
498 "Vector dimension must be greater than 0",
499 ));
500 }
501
502 Ok(Self { dim })
503 }
504}
505
506#[derive(Debug, Clone)]
507struct TextIndexAnnotation {
508 analyzer: String,
509}
510
511impl TextIndexAnnotation {
512 fn parse(args: &Punctuated<syn::Expr, syn::token::Comma>) -> syn::Result<Self> {
513 let mut args_iter = args.iter();
515 let arg = args_iter.next();
516
517 if args_iter.next().is_some() {
518 return Err(syn::Error::new_spanned(
519 args,
520 "Text index attribute only expects one argument, the analyzer to use",
521 ));
522 }
523
524 let analyzer = match arg {
525 Some(syn::Expr::Lit(ExprLit {
526 lit: syn::Lit::Str(strlit),
527 ..
528 })) => strlit.value(),
529 _ => {
530 return Err(syn::Error::new_spanned(
531 arg,
532 "Text index attribute expects a string literal representing the analyzer to use",
533 ));
534 }
535 };
536
537 Ok(Self { analyzer })
538 }
539}