Skip to main content

rivet_error_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{DeriveInput, Fields, LitStr, Meta, parse_macro_input};
4
5#[proc_macro_derive(RivetError, attributes(error))]
6pub fn derive_rivet_error(input: TokenStream) -> TokenStream {
7	let input = parse_macro_input!(input as DeriveInput);
8
9	// Check if this is a struct or an enum
10	match &input.data.clone() {
11		syn::Data::Struct(data_struct) => derive_struct_error(input, data_struct),
12		syn::Data::Enum(data_enum) => derive_enum_error(input, data_enum),
13		_ => panic!("RivetError can only be derived for structs and enums"),
14	}
15}
16
17fn derive_struct_error(input: DeriveInput, data_struct: &syn::DataStruct) -> TokenStream {
18	let struct_name = &input.ident;
19	let vis = &input.vis;
20
21	// Extract error attributes
22	let error_attr = input
23		.attrs
24		.iter()
25		.find(|attr| attr.path().is_ident("error"))
26		.expect("RivetError requires #[error(...)] attribute");
27
28	let args = match &error_attr.meta {
29		Meta::List(meta_list) => {
30			let tokens = &meta_list.tokens;
31			syn::parse2::<ErrorArgs>(tokens.clone())
32				.expect("Failed to parse error attribute arguments")
33		}
34		_ => panic!("error attribute must be in the form #[error(...)]"),
35	};
36
37	// Generate the schema creation
38	let group = &args.group;
39	let code = &args.code;
40	let description = &args.description;
41
42	// Generate the output based on whether we have fields
43	let output = match &data_struct.fields {
44		Fields::Named(fields) => {
45			let field_names = fields.named.iter().map(|f| &f.ident).collect::<Vec<_>>();
46
47			if let Some(formatted) = &args.formatted_desc {
48				quote! {
49					impl #struct_name {
50						#vis fn build(self) -> ::anyhow::Error {
51							use ::rivet_error::{RivetError, RivetErrorSchema, RivetErrorSchemaWithMeta, MacroMarker};
52
53							#[allow(non_upper_case_globals)]
54							static SCHEMA: RivetErrorSchemaWithMeta<#struct_name> = RivetErrorSchemaWithMeta {
55								schema: RivetErrorSchema {
56									group: #group,
57									code: #code,
58									default_message: #description,
59									meta_type: Some(stringify!(#struct_name)),
60									_macro_marker: MacroMarker { _private: () },
61								},
62								message_fn: |meta: &#struct_name| -> String {
63									::rivet_error::indoc::formatdoc! {
64										#formatted,
65										#(#field_names = meta.#field_names),*
66									}
67								},
68								_phantom: ::std::marker::PhantomData,
69							};
70
71							SCHEMA.build_with(self)
72						}
73					}
74				}
75			} else {
76				quote! {
77					impl #struct_name {
78						#vis fn build(self) -> ::anyhow::Error {
79							use ::rivet_error::{RivetError, RivetErrorSchema, MacroMarker};
80
81							#[allow(non_upper_case_globals)]
82							static SCHEMA: RivetErrorSchema = RivetErrorSchema {
83								group: #group,
84								code: #code,
85								default_message: #description,
86								meta_type: Some(stringify!(#struct_name)),
87								_macro_marker: MacroMarker { _private: () },
88							};
89
90							let meta_json = ::serde_json::to_value(&self)
91								.ok()
92								.and_then(|v| ::serde_json::value::to_raw_value(&v).ok());
93
94							let error = RivetError {
95								kind: rivet_error::RivetErrorKind::Static(&SCHEMA),
96								meta: meta_json,
97								message: None,
98								actor: None,
99							};
100							::anyhow::Error::new(error)
101						}
102					}
103				}
104			}
105		}
106		Fields::Unnamed(fields) => {
107			let field_count = fields.unnamed.len();
108			let field_names = (0..field_count)
109				.map(|i| syn::Ident::new(&format!("field{}", i), proc_macro2::Span::call_site()))
110				.collect::<Vec<_>>();
111
112			if let Some(formatted) = &args.formatted_desc {
113				let struct_meta_fields = field_names
114					.iter()
115					.zip(fields.unnamed.iter())
116					.map(|(field_name, field)| {
117						let field_type = &field.ty;
118						quote! { #field_name: #field_type }
119					})
120					.collect::<Vec<_>>();
121				let meta_fields = field_names
122					.iter()
123					.enumerate()
124					.map(|(i, field_name)| {
125						let idx = syn::Index::from(i);
126						quote! { #field_name: self.#idx }
127					})
128					.collect::<Vec<_>>();
129
130				quote! {
131					impl #struct_name {
132						#vis fn build(self) -> ::anyhow::Error {
133							use ::rivet_error::{RivetError, RivetErrorSchema, RivetErrorSchemaWithMeta, MacroMarker};
134
135							#[derive(::serde::Serialize)]
136							struct StructMeta {
137								#(#struct_meta_fields),*
138							}
139
140							let meta = StructMeta {
141								#(#meta_fields),*
142							};
143
144							#[allow(non_upper_case_globals)]
145							static SCHEMA: RivetErrorSchemaWithMeta<StructMeta> = RivetErrorSchemaWithMeta {
146								schema: RivetErrorSchema {
147									group: #group,
148									code: #code,
149									default_message: #description,
150									meta_type: Some(stringify!(#struct_name)),
151									_macro_marker: MacroMarker { _private: () },
152								},
153								message_fn: |meta: &StructMeta| -> String {
154									::rivet_error::indoc::formatdoc! {
155										#formatted,
156										#(meta.#field_names),*
157									}
158								},
159								_phantom: ::std::marker::PhantomData,
160							};
161
162							SCHEMA.build_with(meta)
163						}
164					}
165				}
166			} else {
167				let json_fields = field_names
168					.iter()
169					.map(|field_name| {
170						let field_name_str = field_name.to_string();
171
172						quote! { #field_name_str: #field_name }
173					})
174					.collect::<Vec<_>>();
175
176				quote! {
177					impl #struct_name {
178						#vis fn build(self) -> ::anyhow::Error {
179							use ::rivet_error::{RivetError, RivetErrorSchema, MacroMarker};
180
181							#[allow(non_upper_case_globals)]
182							static SCHEMA: RivetErrorSchema = RivetErrorSchema {
183								group: #group,
184								code: #code,
185								default_message: #description,
186								meta_type: Some(stringify!(#struct_name)),
187								_macro_marker: MacroMarker { _private: () },
188							};
189
190							let meta_value = ::serde_json::json!({
191								#(#json_fields),*
192							});
193
194							let meta_json = ::serde_json::value::to_raw_value(&meta_value).ok();
195
196							let error = RivetError {
197								kind: rivet_error::RivetErrorKind::Static(&SCHEMA),
198								meta: meta_json,
199								message: None,
200								actor: None,
201							};
202							::anyhow::Error::new(error)
203						}
204					}
205				}
206			}
207		}
208		Fields::Unit => {
209			quote! {
210				impl #struct_name {
211					#vis fn build(self) -> ::anyhow::Error {
212						use ::rivet_error::{RivetError, RivetErrorSchema, MacroMarker};
213
214						#[allow(non_upper_case_globals)]
215						static SCHEMA: RivetErrorSchema = RivetErrorSchema {
216							group: #group,
217							code: #code,
218							default_message: #description,
219							meta_type: None,
220							_macro_marker: MacroMarker { _private: () },
221						};
222
223						SCHEMA.build()
224					}
225				}
226			}
227		}
228	};
229
230	// Write error documentation
231	if let Err(e) = write_error_doc(&args.group, &args.code, &args.description) {
232		panic!(
233			"Failed to write error documentation for {}.{}: {}",
234			args.group, args.code, e
235		);
236	}
237
238	// eprintln!("\n\n{output}\n");
239
240	TokenStream::from(output)
241}
242
243fn derive_enum_error(input: DeriveInput, data_enum: &syn::DataEnum) -> TokenStream {
244	let enum_name = &input.ident;
245	let vis = &input.vis;
246
247	// Extract group name from enum-level error attribute
248	let error_attr = input
249		.attrs
250		.iter()
251		.find(|attr| attr.path().is_ident("error"))
252		.expect("RivetError on enum requires #[error(\"group\")] attribute");
253
254	let group = match &error_attr.meta {
255		Meta::List(meta_list) => {
256			let tokens = &meta_list.tokens;
257			let group_str = syn::parse2::<LitStr>(tokens.clone())
258				.expect("Failed to parse enum error attribute arguments");
259			group_str.value()
260		}
261		_ => panic!("error attribute for enum must be in the form #[error(\"group\")]"),
262	};
263
264	let mut variant_matches = Vec::new();
265
266	// Process each variant
267	for variant in &data_enum.variants {
268		let variant_name = &variant.ident;
269
270		// Extract error attributes from variant
271		let variant_error_attr = variant
272			.attrs
273			.iter()
274			.find(|attr| attr.path().is_ident("error"))
275			.expect(&format!(
276				"Variant {} requires #[error(...)] attribute",
277				variant_name
278			));
279
280		let (code, description, formatted_desc) = match &variant_error_attr.meta {
281			Meta::List(meta_list) => {
282				let tokens = &meta_list.tokens;
283				parse_variant_error_args(tokens).expect(&format!(
284					"Failed to parse variant error attributes for {}",
285					variant_name
286				))
287			}
288			_ => panic!(
289				"error attribute for variant must be in the form #[error(\"code\", \"description\")]"
290			),
291		};
292
293		// Write error documentation
294		if let Err(e) = write_error_doc(&group, &code, &description) {
295			panic!(
296				"Failed to write error documentation for {}.{}: {}",
297				group, code, e
298			);
299		}
300
301		// Handle variants with fields
302		match &variant.fields {
303			Fields::Named(fields) => {
304				let field_names = fields.named.iter().map(|f| &f.ident).collect::<Vec<_>>();
305				let field_patterns = quote! { { #(#field_names),* } };
306
307				if let Some(formatted) = &formatted_desc {
308					variant_matches.push(quote! {
309						#enum_name::#variant_name #field_patterns => {
310							use ::rivet_error::{RivetError, RivetErrorSchema, RivetErrorSchemaWithMeta, MacroMarker};
311
312							#[derive(Serialize)]
313							struct VariantMeta #fields
314
315							let meta = VariantMeta {
316								#(#field_names: #field_names),*
317							};
318
319							#[allow(non_upper_case_globals)]
320							static SCHEMA: RivetErrorSchemaWithMeta<VariantMeta> = RivetErrorSchemaWithMeta {
321								schema: RivetErrorSchema {
322									group: #group,
323									code: #code,
324									default_message: #description,
325									meta_type: Some(stringify!(#enum_name::#variant_name)),
326									_macro_marker: MacroMarker { _private: () },
327								},
328								message_fn: |meta: &VariantMeta| -> String {
329									::rivet_error::indoc::formatdoc! {
330										#formatted,
331										#(#field_names = meta.#field_names),*
332									}
333								},
334								_phantom: ::std::marker::PhantomData,
335							};
336
337							SCHEMA.build_with(meta)
338						}
339					});
340				} else {
341					let json_fields = field_names
342						.iter()
343						.map(|field_name| {
344							let field_name_str = field_name.as_ref().map(|x| x.to_string());
345
346							quote! { #field_name_str: #field_name }
347						})
348						.collect::<Vec<_>>();
349
350					variant_matches.push(quote! {
351						#enum_name::#variant_name #field_patterns => {
352							use ::rivet_error::{RivetError, RivetErrorSchema, MacroMarker};
353
354							#[allow(non_upper_case_globals)]
355							static SCHEMA: RivetErrorSchema = RivetErrorSchema {
356								group: #group,
357								code: #code,
358								default_message: #description,
359								meta_type: Some(stringify!(#enum_name::#variant_name)),
360								_macro_marker: MacroMarker { _private: () },
361							};
362
363							let meta_value = ::serde_json::json!({
364								#(#json_fields),*
365							});
366
367							let meta_json = ::serde_json::value::to_raw_value(&meta_value).ok();
368
369							let error = RivetError {
370								kind: rivet_error::RivetErrorKind::Static(&SCHEMA),
371								meta: meta_json,
372								message: None,
373								actor: None,
374							};
375							::anyhow::Error::new(error)
376						}
377					});
378				}
379			}
380			Fields::Unnamed(fields) => {
381				let field_count = fields.unnamed.len();
382				let field_names = (0..field_count)
383					.map(|i| {
384						syn::Ident::new(&format!("field{}", i), proc_macro2::Span::call_site())
385					})
386					.collect::<Vec<_>>();
387				let field_patterns = quote! { ( #(#field_names),* ) };
388
389				if let Some(formatted) = &formatted_desc {
390					let meta_fields = field_names
391						.iter()
392						.zip(fields.unnamed.iter())
393						.map(|(field_name, field)| quote! { #field_name: #field })
394						.collect::<Vec<_>>();
395
396					variant_matches.push(quote! {
397						#enum_name::#variant_name #field_patterns => {
398							use ::rivet_error::{RivetError, RivetErrorSchema, RivetErrorSchemaWithMeta, MacroMarker};
399
400							#[derive(Serialize)]
401							struct VariantMeta {
402								#(#meta_fields),*
403							}
404
405							let meta = VariantMeta {
406								#(#field_names: #field_names),*
407							};
408
409							#[allow(non_upper_case_globals)]
410							static SCHEMA: RivetErrorSchemaWithMeta<VariantMeta> = RivetErrorSchemaWithMeta {
411								schema: RivetErrorSchema {
412									group: #group,
413									code: #code,
414									default_message: #description,
415									meta_type: Some(stringify!(#enum_name::#variant_name)),
416									_macro_marker: MacroMarker { _private: () },
417								},
418								message_fn: |meta: &VariantMeta| -> String {
419									::rivet_error::indoc::formatdoc! {
420										#formatted,
421										#(meta.#field_names),*
422									}
423								},
424								_phantom: ::std::marker::PhantomData,
425							};
426
427							SCHEMA.build_with(meta)
428						}
429					});
430				} else {
431					let json_fields = field_names
432						.iter()
433						.map(|field_name| {
434							let field_name_str = field_name.to_string();
435
436							quote! { #field_name_str: #field_name }
437						})
438						.collect::<Vec<_>>();
439
440					variant_matches.push(quote! {
441						#enum_name::#variant_name #field_patterns => {
442							use ::rivet_error::{RivetError, RivetErrorSchema, MacroMarker};
443
444							#[allow(non_upper_case_globals)]
445							static SCHEMA: RivetErrorSchema = RivetErrorSchema {
446								group: #group,
447								code: #code,
448								default_message: #description,
449								meta_type: Some(stringify!(#enum_name::#variant_name)),
450								_macro_marker: MacroMarker { _private: () },
451							};
452
453							let meta_value = ::serde_json::json!({
454								#(#json_fields),*
455							});
456
457							let meta_json = ::serde_json::value::to_raw_value(&meta_value).ok();
458
459							let error = RivetError {
460								kind: rivet_error::RivetErrorKind::Static(&SCHEMA),
461								meta: meta_json,
462								message: None,
463								actor: None,
464							};
465							::anyhow::Error::new(error)
466						}
467					});
468				}
469			}
470			Fields::Unit => {
471				// Handle unit variants
472				variant_matches.push(quote! {
473					#enum_name::#variant_name => {
474						use ::rivet_error::{RivetError, RivetErrorSchema, MacroMarker};
475
476						#[allow(non_upper_case_globals)]
477						static SCHEMA: RivetErrorSchema = RivetErrorSchema {
478							group: #group,
479							code: #code,
480							default_message: #description,
481							meta_type: None,
482							_macro_marker: MacroMarker { _private: () },
483						};
484
485						SCHEMA.build()
486					}
487				});
488			}
489		}
490	}
491
492	let output = quote! {
493		impl #enum_name {
494			#vis fn build(self) -> ::anyhow::Error {
495				match self {
496					#(#variant_matches),*
497				}
498			}
499		}
500	};
501
502	TokenStream::from(output)
503}
504
505fn parse_variant_error_args(
506	tokens: &proc_macro2::TokenStream,
507) -> syn::Result<(String, String, Option<String>)> {
508	struct VariantErrorArgs {
509		code: String,
510		description: String,
511		formatted_desc: Option<String>,
512	}
513
514	impl syn::parse::Parse for VariantErrorArgs {
515		fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
516			let code = input.parse::<LitStr>()?.value();
517			input.parse::<syn::Token![,]>()?;
518
519			let description = input.parse::<LitStr>()?.value();
520
521			let mut formatted_desc = None;
522			if input.peek(syn::Token![,]) {
523				input.parse::<syn::Token![,]>()?;
524				if input.peek(LitStr) {
525					formatted_desc = Some(input.parse::<LitStr>()?.value());
526				}
527			}
528
529			Ok(VariantErrorArgs {
530				code,
531				description,
532				formatted_desc,
533			})
534		}
535	}
536
537	let args = syn::parse2::<VariantErrorArgs>(tokens.clone())?;
538	Ok((args.code, args.description, args.formatted_desc))
539}
540
541struct ErrorArgs {
542	group: String,
543	code: String,
544	description: String,
545	formatted_desc: Option<String>,
546}
547
548impl syn::parse::Parse for ErrorArgs {
549	fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
550		let group = input.parse::<LitStr>()?.value();
551		input.parse::<syn::Token![,]>()?;
552
553		let code = input.parse::<LitStr>()?.value();
554		input.parse::<syn::Token![,]>()?;
555
556		let description = input.parse::<LitStr>()?.value();
557
558		let mut formatted_desc = None;
559		// Check if there's a formatted description
560		if input.peek(syn::Token![,]) {
561			input.parse::<syn::Token![,]>()?;
562			if input.peek(LitStr) {
563				formatted_desc = Some(input.parse::<LitStr>()?.value());
564			}
565		}
566
567		Ok(ErrorArgs {
568			group,
569			code,
570			description,
571			formatted_desc,
572		})
573	}
574}
575
576fn write_error_doc(group: &str, code: &str, message: &str) -> std::io::Result<()> {
577	use std::fs;
578	use std::io::Write;
579
580	let workspace_root = find_workspace_root()?;
581	let errors_dir = if std::env::var("RIVET_ERROR_OUTPUT_DIR").is_ok() {
582		// If custom dir is specified, errors go directly there
583		workspace_root
584	} else {
585		// Otherwise use the standard out/errors path
586		workspace_root.join("engine/artifacts/errors")
587	};
588	fs::create_dir_all(&errors_dir)?;
589
590	let filename = format!("{group}.{code}.json");
591	let filepath = errors_dir.join(filename);
592
593	// Create JSON structure
594	let error_doc = serde_json::json!({
595		"group": group,
596		"code": code,
597		"message": message
598	});
599
600	let content = serde_json::to_string_pretty(&error_doc)?;
601
602	let mut file = fs::File::create(filepath)?;
603	file.write_all(content.as_bytes())?;
604
605	Ok(())
606}
607
/// Resolves the base directory that error documentation is written relative to.
///
/// Resolution order:
/// 1. `RIVET_ERROR_OUTPUT_DIR`, if set, is returned verbatim.
/// 2. Otherwise, walking up from `CARGO_MANIFEST_DIR` (starting with that directory itself),
///    the first directory whose `Cargo.toml` contains a real `[workspace]` table header.
///    Mentions of `[workspace]` inside comments or other values do not count.
/// 3. If no such manifest is found, `CARGO_MANIFEST_DIR` itself.
///
/// # Errors
///
/// Returns `NotFound` if `CARGO_MANIFEST_DIR` is unset or not valid Unicode. Propagates I/O
/// errors from reading a candidate `Cargo.toml`.
fn find_workspace_root() -> std::io::Result<std::path::PathBuf> {
	use std::path::{Path, PathBuf};

	/// True when some line of the manifest is a `[workspace]` table header, ignoring trailing
	/// `#` comments and whitespace inside the brackets.
	fn declares_workspace(manifest: &str) -> bool {
		manifest.lines().any(|line| {
			let code = line.split('#').next().unwrap_or("").trim();
			code.strip_prefix('[')
				.and_then(|rest| rest.strip_suffix(']'))
				.map(str::trim)
				== Some("workspace")
		})
	}

	if let Ok(custom_dir) = std::env::var("RIVET_ERROR_OUTPUT_DIR") {
		return Ok(PathBuf::from(custom_dir));
	}

	let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
		.map_err(|e| std::io::Error::new(std::io::ErrorKind::NotFound, e))?;
	let manifest_dir = Path::new(&manifest_dir);

	for dir in manifest_dir.ancestors() {
		let manifest = dir.join("Cargo.toml");
		if manifest.exists() && declares_workspace(&std::fs::read_to_string(&manifest)?) {
			return Ok(dir.to_path_buf());
		}
	}

	Ok(manifest_dir.to_path_buf())
}