1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{DeriveInput, Fields, LitStr, Meta, parse_macro_input};
4
5#[proc_macro_derive(RivetError, attributes(error))]
6pub fn derive_rivet_error(input: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(input as DeriveInput);
8
9 match &input.data.clone() {
11 syn::Data::Struct(data_struct) => derive_struct_error(input, data_struct),
12 syn::Data::Enum(data_enum) => derive_enum_error(input, data_enum),
13 _ => panic!("RivetError can only be derived for structs and enums"),
14 }
15}
16
17fn derive_struct_error(input: DeriveInput, data_struct: &syn::DataStruct) -> TokenStream {
18 let struct_name = &input.ident;
19 let vis = &input.vis;
20
21 let error_attr = input
23 .attrs
24 .iter()
25 .find(|attr| attr.path().is_ident("error"))
26 .expect("RivetError requires #[error(...)] attribute");
27
28 let args = match &error_attr.meta {
29 Meta::List(meta_list) => {
30 let tokens = &meta_list.tokens;
31 syn::parse2::<ErrorArgs>(tokens.clone())
32 .expect("Failed to parse error attribute arguments")
33 }
34 _ => panic!("error attribute must be in the form #[error(...)]"),
35 };
36
37 let group = &args.group;
39 let code = &args.code;
40 let description = &args.description;
41
42 let output = match &data_struct.fields {
44 Fields::Named(fields) => {
45 let field_names = fields.named.iter().map(|f| &f.ident).collect::<Vec<_>>();
46
47 if let Some(formatted) = &args.formatted_desc {
48 quote! {
49 impl #struct_name {
50 #vis fn build(self) -> ::anyhow::Error {
51 use ::rivet_error::{RivetError, RivetErrorSchema, RivetErrorSchemaWithMeta, MacroMarker};
52
53 #[allow(non_upper_case_globals)]
54 static SCHEMA: RivetErrorSchemaWithMeta<#struct_name> = RivetErrorSchemaWithMeta {
55 schema: RivetErrorSchema {
56 group: #group,
57 code: #code,
58 default_message: #description,
59 meta_type: Some(stringify!(#struct_name)),
60 _macro_marker: MacroMarker { _private: () },
61 },
62 message_fn: |meta: &#struct_name| -> String {
63 ::rivet_error::indoc::formatdoc! {
64 #formatted,
65 #(#field_names = meta.#field_names),*
66 }
67 },
68 _phantom: ::std::marker::PhantomData,
69 };
70
71 SCHEMA.build_with(self)
72 }
73 }
74 }
75 } else {
76 quote! {
77 impl #struct_name {
78 #vis fn build(self) -> ::anyhow::Error {
79 use ::rivet_error::{RivetError, RivetErrorSchema, MacroMarker};
80
81 #[allow(non_upper_case_globals)]
82 static SCHEMA: RivetErrorSchema = RivetErrorSchema {
83 group: #group,
84 code: #code,
85 default_message: #description,
86 meta_type: Some(stringify!(#struct_name)),
87 _macro_marker: MacroMarker { _private: () },
88 };
89
90 let meta_json = ::serde_json::to_value(&self)
91 .ok()
92 .and_then(|v| ::serde_json::value::to_raw_value(&v).ok());
93
94 let error = RivetError {
95 kind: rivet_error::RivetErrorKind::Static(&SCHEMA),
96 meta: meta_json,
97 message: None,
98 actor: None,
99 };
100 ::anyhow::Error::new(error)
101 }
102 }
103 }
104 }
105 }
106 Fields::Unnamed(fields) => {
107 let field_count = fields.unnamed.len();
108 let field_names = (0..field_count)
109 .map(|i| syn::Ident::new(&format!("field{}", i), proc_macro2::Span::call_site()))
110 .collect::<Vec<_>>();
111
112 if let Some(formatted) = &args.formatted_desc {
113 let struct_meta_fields = field_names
114 .iter()
115 .zip(fields.unnamed.iter())
116 .map(|(field_name, field)| {
117 let field_type = &field.ty;
118 quote! { #field_name: #field_type }
119 })
120 .collect::<Vec<_>>();
121 let meta_fields = field_names
122 .iter()
123 .enumerate()
124 .map(|(i, field_name)| {
125 let idx = syn::Index::from(i);
126 quote! { #field_name: self.#idx }
127 })
128 .collect::<Vec<_>>();
129
130 quote! {
131 impl #struct_name {
132 #vis fn build(self) -> ::anyhow::Error {
133 use ::rivet_error::{RivetError, RivetErrorSchema, RivetErrorSchemaWithMeta, MacroMarker};
134
135 #[derive(::serde::Serialize)]
136 struct StructMeta {
137 #(#struct_meta_fields),*
138 }
139
140 let meta = StructMeta {
141 #(#meta_fields),*
142 };
143
144 #[allow(non_upper_case_globals)]
145 static SCHEMA: RivetErrorSchemaWithMeta<StructMeta> = RivetErrorSchemaWithMeta {
146 schema: RivetErrorSchema {
147 group: #group,
148 code: #code,
149 default_message: #description,
150 meta_type: Some(stringify!(#struct_name)),
151 _macro_marker: MacroMarker { _private: () },
152 },
153 message_fn: |meta: &StructMeta| -> String {
154 ::rivet_error::indoc::formatdoc! {
155 #formatted,
156 #(meta.#field_names),*
157 }
158 },
159 _phantom: ::std::marker::PhantomData,
160 };
161
162 SCHEMA.build_with(meta)
163 }
164 }
165 }
166 } else {
167 let json_fields = field_names
168 .iter()
169 .map(|field_name| {
170 let field_name_str = field_name.to_string();
171
172 quote! { #field_name_str: #field_name }
173 })
174 .collect::<Vec<_>>();
175
176 quote! {
177 impl #struct_name {
178 #vis fn build(self) -> ::anyhow::Error {
179 use ::rivet_error::{RivetError, RivetErrorSchema, MacroMarker};
180
181 #[allow(non_upper_case_globals)]
182 static SCHEMA: RivetErrorSchema = RivetErrorSchema {
183 group: #group,
184 code: #code,
185 default_message: #description,
186 meta_type: Some(stringify!(#struct_name)),
187 _macro_marker: MacroMarker { _private: () },
188 };
189
190 let meta_value = ::serde_json::json!({
191 #(#json_fields),*
192 });
193
194 let meta_json = ::serde_json::value::to_raw_value(&meta_value).ok();
195
196 let error = RivetError {
197 kind: rivet_error::RivetErrorKind::Static(&SCHEMA),
198 meta: meta_json,
199 message: None,
200 actor: None,
201 };
202 ::anyhow::Error::new(error)
203 }
204 }
205 }
206 }
207 }
208 Fields::Unit => {
209 quote! {
210 impl #struct_name {
211 #vis fn build(self) -> ::anyhow::Error {
212 use ::rivet_error::{RivetError, RivetErrorSchema, MacroMarker};
213
214 #[allow(non_upper_case_globals)]
215 static SCHEMA: RivetErrorSchema = RivetErrorSchema {
216 group: #group,
217 code: #code,
218 default_message: #description,
219 meta_type: None,
220 _macro_marker: MacroMarker { _private: () },
221 };
222
223 SCHEMA.build()
224 }
225 }
226 }
227 }
228 };
229
230 if let Err(e) = write_error_doc(&args.group, &args.code, &args.description) {
232 panic!(
233 "Failed to write error documentation for {}.{}: {}",
234 args.group, args.code, e
235 );
236 }
237
238 TokenStream::from(output)
241}
242
243fn derive_enum_error(input: DeriveInput, data_enum: &syn::DataEnum) -> TokenStream {
244 let enum_name = &input.ident;
245 let vis = &input.vis;
246
247 let error_attr = input
249 .attrs
250 .iter()
251 .find(|attr| attr.path().is_ident("error"))
252 .expect("RivetError on enum requires #[error(\"group\")] attribute");
253
254 let group = match &error_attr.meta {
255 Meta::List(meta_list) => {
256 let tokens = &meta_list.tokens;
257 let group_str = syn::parse2::<LitStr>(tokens.clone())
258 .expect("Failed to parse enum error attribute arguments");
259 group_str.value()
260 }
261 _ => panic!("error attribute for enum must be in the form #[error(\"group\")]"),
262 };
263
264 let mut variant_matches = Vec::new();
265
266 for variant in &data_enum.variants {
268 let variant_name = &variant.ident;
269
270 let variant_error_attr = variant
272 .attrs
273 .iter()
274 .find(|attr| attr.path().is_ident("error"))
275 .expect(&format!(
276 "Variant {} requires #[error(...)] attribute",
277 variant_name
278 ));
279
280 let (code, description, formatted_desc) = match &variant_error_attr.meta {
281 Meta::List(meta_list) => {
282 let tokens = &meta_list.tokens;
283 parse_variant_error_args(tokens).expect(&format!(
284 "Failed to parse variant error attributes for {}",
285 variant_name
286 ))
287 }
288 _ => panic!(
289 "error attribute for variant must be in the form #[error(\"code\", \"description\")]"
290 ),
291 };
292
293 if let Err(e) = write_error_doc(&group, &code, &description) {
295 panic!(
296 "Failed to write error documentation for {}.{}: {}",
297 group, code, e
298 );
299 }
300
301 match &variant.fields {
303 Fields::Named(fields) => {
304 let field_names = fields.named.iter().map(|f| &f.ident).collect::<Vec<_>>();
305 let field_patterns = quote! { { #(#field_names),* } };
306
307 if let Some(formatted) = &formatted_desc {
308 variant_matches.push(quote! {
309 #enum_name::#variant_name #field_patterns => {
310 use ::rivet_error::{RivetError, RivetErrorSchema, RivetErrorSchemaWithMeta, MacroMarker};
311
312 #[derive(Serialize)]
313 struct VariantMeta #fields
314
315 let meta = VariantMeta {
316 #(#field_names: #field_names),*
317 };
318
319 #[allow(non_upper_case_globals)]
320 static SCHEMA: RivetErrorSchemaWithMeta<VariantMeta> = RivetErrorSchemaWithMeta {
321 schema: RivetErrorSchema {
322 group: #group,
323 code: #code,
324 default_message: #description,
325 meta_type: Some(stringify!(#enum_name::#variant_name)),
326 _macro_marker: MacroMarker { _private: () },
327 },
328 message_fn: |meta: &VariantMeta| -> String {
329 ::rivet_error::indoc::formatdoc! {
330 #formatted,
331 #(#field_names = meta.#field_names),*
332 }
333 },
334 _phantom: ::std::marker::PhantomData,
335 };
336
337 SCHEMA.build_with(meta)
338 }
339 });
340 } else {
341 let json_fields = field_names
342 .iter()
343 .map(|field_name| {
344 let field_name_str = field_name.as_ref().map(|x| x.to_string());
345
346 quote! { #field_name_str: #field_name }
347 })
348 .collect::<Vec<_>>();
349
350 variant_matches.push(quote! {
351 #enum_name::#variant_name #field_patterns => {
352 use ::rivet_error::{RivetError, RivetErrorSchema, MacroMarker};
353
354 #[allow(non_upper_case_globals)]
355 static SCHEMA: RivetErrorSchema = RivetErrorSchema {
356 group: #group,
357 code: #code,
358 default_message: #description,
359 meta_type: Some(stringify!(#enum_name::#variant_name)),
360 _macro_marker: MacroMarker { _private: () },
361 };
362
363 let meta_value = ::serde_json::json!({
364 #(#json_fields),*
365 });
366
367 let meta_json = ::serde_json::value::to_raw_value(&meta_value).ok();
368
369 let error = RivetError {
370 kind: rivet_error::RivetErrorKind::Static(&SCHEMA),
371 meta: meta_json,
372 message: None,
373 actor: None,
374 };
375 ::anyhow::Error::new(error)
376 }
377 });
378 }
379 }
380 Fields::Unnamed(fields) => {
381 let field_count = fields.unnamed.len();
382 let field_names = (0..field_count)
383 .map(|i| {
384 syn::Ident::new(&format!("field{}", i), proc_macro2::Span::call_site())
385 })
386 .collect::<Vec<_>>();
387 let field_patterns = quote! { ( #(#field_names),* ) };
388
389 if let Some(formatted) = &formatted_desc {
390 let meta_fields = field_names
391 .iter()
392 .zip(fields.unnamed.iter())
393 .map(|(field_name, field)| quote! { #field_name: #field })
394 .collect::<Vec<_>>();
395
396 variant_matches.push(quote! {
397 #enum_name::#variant_name #field_patterns => {
398 use ::rivet_error::{RivetError, RivetErrorSchema, RivetErrorSchemaWithMeta, MacroMarker};
399
400 #[derive(Serialize)]
401 struct VariantMeta {
402 #(#meta_fields),*
403 }
404
405 let meta = VariantMeta {
406 #(#field_names: #field_names),*
407 };
408
409 #[allow(non_upper_case_globals)]
410 static SCHEMA: RivetErrorSchemaWithMeta<VariantMeta> = RivetErrorSchemaWithMeta {
411 schema: RivetErrorSchema {
412 group: #group,
413 code: #code,
414 default_message: #description,
415 meta_type: Some(stringify!(#enum_name::#variant_name)),
416 _macro_marker: MacroMarker { _private: () },
417 },
418 message_fn: |meta: &VariantMeta| -> String {
419 ::rivet_error::indoc::formatdoc! {
420 #formatted,
421 #(meta.#field_names),*
422 }
423 },
424 _phantom: ::std::marker::PhantomData,
425 };
426
427 SCHEMA.build_with(meta)
428 }
429 });
430 } else {
431 let json_fields = field_names
432 .iter()
433 .map(|field_name| {
434 let field_name_str = field_name.to_string();
435
436 quote! { #field_name_str: #field_name }
437 })
438 .collect::<Vec<_>>();
439
440 variant_matches.push(quote! {
441 #enum_name::#variant_name #field_patterns => {
442 use ::rivet_error::{RivetError, RivetErrorSchema, MacroMarker};
443
444 #[allow(non_upper_case_globals)]
445 static SCHEMA: RivetErrorSchema = RivetErrorSchema {
446 group: #group,
447 code: #code,
448 default_message: #description,
449 meta_type: Some(stringify!(#enum_name::#variant_name)),
450 _macro_marker: MacroMarker { _private: () },
451 };
452
453 let meta_value = ::serde_json::json!({
454 #(#json_fields),*
455 });
456
457 let meta_json = ::serde_json::value::to_raw_value(&meta_value).ok();
458
459 let error = RivetError {
460 kind: rivet_error::RivetErrorKind::Static(&SCHEMA),
461 meta: meta_json,
462 message: None,
463 actor: None,
464 };
465 ::anyhow::Error::new(error)
466 }
467 });
468 }
469 }
470 Fields::Unit => {
471 variant_matches.push(quote! {
473 #enum_name::#variant_name => {
474 use ::rivet_error::{RivetError, RivetErrorSchema, MacroMarker};
475
476 #[allow(non_upper_case_globals)]
477 static SCHEMA: RivetErrorSchema = RivetErrorSchema {
478 group: #group,
479 code: #code,
480 default_message: #description,
481 meta_type: None,
482 _macro_marker: MacroMarker { _private: () },
483 };
484
485 SCHEMA.build()
486 }
487 });
488 }
489 }
490 }
491
492 let output = quote! {
493 impl #enum_name {
494 #vis fn build(self) -> ::anyhow::Error {
495 match self {
496 #(#variant_matches),*
497 }
498 }
499 }
500 };
501
502 TokenStream::from(output)
503}
504
505fn parse_variant_error_args(
506 tokens: &proc_macro2::TokenStream,
507) -> syn::Result<(String, String, Option<String>)> {
508 struct VariantErrorArgs {
509 code: String,
510 description: String,
511 formatted_desc: Option<String>,
512 }
513
514 impl syn::parse::Parse for VariantErrorArgs {
515 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
516 let code = input.parse::<LitStr>()?.value();
517 input.parse::<syn::Token![,]>()?;
518
519 let description = input.parse::<LitStr>()?.value();
520
521 let mut formatted_desc = None;
522 if input.peek(syn::Token![,]) {
523 input.parse::<syn::Token![,]>()?;
524 if input.peek(LitStr) {
525 formatted_desc = Some(input.parse::<LitStr>()?.value());
526 }
527 }
528
529 Ok(VariantErrorArgs {
530 code,
531 description,
532 formatted_desc,
533 })
534 }
535 }
536
537 let args = syn::parse2::<VariantErrorArgs>(tokens.clone())?;
538 Ok((args.code, args.description, args.formatted_desc))
539}
540
541struct ErrorArgs {
542 group: String,
543 code: String,
544 description: String,
545 formatted_desc: Option<String>,
546}
547
548impl syn::parse::Parse for ErrorArgs {
549 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
550 let group = input.parse::<LitStr>()?.value();
551 input.parse::<syn::Token![,]>()?;
552
553 let code = input.parse::<LitStr>()?.value();
554 input.parse::<syn::Token![,]>()?;
555
556 let description = input.parse::<LitStr>()?.value();
557
558 let mut formatted_desc = None;
559 if input.peek(syn::Token![,]) {
561 input.parse::<syn::Token![,]>()?;
562 if input.peek(LitStr) {
563 formatted_desc = Some(input.parse::<LitStr>()?.value());
564 }
565 }
566
567 Ok(ErrorArgs {
568 group,
569 code,
570 description,
571 formatted_desc,
572 })
573 }
574}
575
576fn write_error_doc(group: &str, code: &str, message: &str) -> std::io::Result<()> {
577 use std::fs;
578 use std::io::Write;
579
580 let workspace_root = find_workspace_root()?;
581 let errors_dir = if std::env::var("RIVET_ERROR_OUTPUT_DIR").is_ok() {
582 workspace_root
584 } else {
585 workspace_root.join("engine/artifacts/errors")
587 };
588 fs::create_dir_all(&errors_dir)?;
589
590 let filename = format!("{group}.{code}.json");
591 let filepath = errors_dir.join(filename);
592
593 let error_doc = serde_json::json!({
595 "group": group,
596 "code": code,
597 "message": message
598 });
599
600 let content = serde_json::to_string_pretty(&error_doc)?;
601
602 let mut file = fs::File::create(filepath)?;
603 file.write_all(content.as_bytes())?;
604
605 Ok(())
606}
607
/// Returns the directory that error artifacts are resolved against.
///
/// If `RIVET_ERROR_OUTPUT_DIR` is set, its value is returned as-is. Otherwise the search starts
/// at `CARGO_MANIFEST_DIR` and walks towards the filesystem root. It returns the first
/// directory whose `Cargo.toml` contains the text `[workspace]`. If no such directory exists,
/// `CARGO_MANIFEST_DIR` itself is returned.
///
/// # Errors
/// Returns `ErrorKind::NotFound` when `CARGO_MANIFEST_DIR` is unset or not valid Unicode, and
/// propagates any error from reading a candidate `Cargo.toml`.
fn find_workspace_root() -> std::io::Result<std::path::PathBuf> {
    use std::path::Path;

    if let Ok(custom_dir) = std::env::var("RIVET_ERROR_OUTPUT_DIR") {
        return Ok(std::path::PathBuf::from(custom_dir));
    }

    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
        .map_err(|err| std::io::Error::new(std::io::ErrorKind::NotFound, err))?;
    let start = Path::new(&manifest_dir);

    // `ancestors()` yields `start` itself, then each parent up to the root.
    for dir in start.ancestors() {
        let manifest = dir.join("Cargo.toml");
        if !manifest.exists() {
            continue;
        }
        // Plain substring check rather than a TOML parse.
        if std::fs::read_to_string(&manifest)?.contains("[workspace]") {
            return Ok(dir.to_path_buf());
        }
    }

    Ok(start.to_path_buf())
}