1use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::{quote, ToTokens};
9use syn::{
10 parse_macro_input, spanned::Spanned, Attribute, Data, DeriveInput, Error, Expr, ExprLit,
11 ExprPath, Fields, Lit, LitStr, Result, Variant,
12};
13
14#[proc_macro_derive(ErrorCode, attributes(orion_error))]
15pub fn derive_error_code(input: TokenStream) -> TokenStream {
16 expand_error_code(parse_macro_input!(input as DeriveInput)).into()
17}
18
19#[proc_macro_derive(ErrorIdentityProvider, attributes(orion_error))]
20pub fn derive_error_identity_provider(input: TokenStream) -> TokenStream {
21 expand_error_identity_provider(parse_macro_input!(input as DeriveInput)).into()
22}
23
24#[proc_macro_derive(OrionError, attributes(orion_error))]
25pub fn derive_orion_error(input: TokenStream) -> TokenStream {
26 expand_orion_error(parse_macro_input!(input as DeriveInput)).into()
27}
28
29fn expand_error_code(input: DeriveInput) -> TokenStream2 {
30 match impl_error_code(input, MissingCode::Error) {
31 Ok(tokens) => tokens,
32 Err(err) => err.to_compile_error(),
33 }
34}
35
36fn expand_error_identity_provider(input: DeriveInput) -> TokenStream2 {
37 match impl_error_identity_provider(input) {
38 Ok(tokens) => tokens,
39 Err(err) => err.to_compile_error(),
40 }
41}
42
43fn expand_orion_error(input: DeriveInput) -> TokenStream2 {
44 let display = impl_display(input.clone());
45 let error_code = impl_error_code(input.clone(), MissingCode::Default);
46 let identity_provider = impl_error_identity_provider(input.clone());
47 let domain_reason = impl_domain_reason(input);
48
49 let mut out = TokenStream2::new();
50 let mut errors = Vec::new();
51 for result in [display, error_code, identity_provider, domain_reason] {
52 match result {
53 Ok(tokens) => out.extend(tokens),
54 Err(err) => errors.push(err),
55 }
56 }
57
58 match errors.into_iter().reduce(|mut first, second| {
59 first.combine(second);
60 first
61 }) {
62 Some(first) => first.to_compile_error(),
63 None => out,
64 }
65}
66
67fn impl_domain_reason(input: DeriveInput) -> Result<TokenStream2> {
68 let ident = input.ident;
69 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
70
71 match input.data {
72 Data::Enum(_) | Data::Struct(_) => Ok(quote! {
73 impl #impl_generics ::orion_error::DomainReason for #ident #ty_generics #where_clause {}
74 }),
75 Data::Union(_) => Err(Error::new(
76 ident.span(),
77 "OrionError can only be derived for enums or structs",
78 )),
79 }
80}
81
82fn impl_display(input: DeriveInput) -> Result<TokenStream2> {
83 let ident = input.ident;
84 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
85
86 match input.data {
87 Data::Enum(data) => {
88 let arms = data
89 .variants
90 .iter()
91 .map(|variant| {
92 let args = OrionAttrs::from_attrs(&variant.attrs)?;
93 if args.transparent {
94 let (pattern, inner) = transparent_variant_pattern(variant)?;
95 Ok(quote! {
96 #pattern => ::std::fmt::Display::fmt(#inner, f)
97 })
98 } else if let Some(message) = args.display_message() {
99 let pattern = variant_pattern(variant);
100 Ok(quote! {
101 #pattern => f.write_str(#message)
102 })
103 } else {
104 Err(Error::new(
105 variant.span(),
106 "missing #[orion_error(message = ...)] or string literal #[orion_error(identity = ...)]",
107 ))
108 }
109 })
110 .collect::<Result<Vec<_>>>()?;
111
112 Ok(quote! {
113 impl #impl_generics ::std::fmt::Display for #ident #ty_generics #where_clause {
114 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
115 match self {
116 #(#arms,)*
117 }
118 }
119 }
120 })
121 }
122 Data::Struct(data) => {
123 let args = OrionAttrs::from_attrs(&input.attrs)?;
124 let body = if args.transparent {
125 let inner = transparent_struct_binding(&data.fields)?;
126 let pattern = struct_pattern(&ident, &data.fields);
127 quote! {
128 match self {
129 #pattern => ::std::fmt::Display::fmt(#inner, f),
130 }
131 }
132 } else if let Some(message) = args.display_message() {
133 quote! { f.write_str(#message) }
134 } else {
135 return Err(Error::new(
136 ident.span(),
137 "missing container #[orion_error(message = ...)] or string literal #[orion_error(identity = ...)]",
138 ));
139 };
140
141 Ok(quote! {
142 impl #impl_generics ::std::fmt::Display for #ident #ty_generics #where_clause {
143 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
144 #body
145 }
146 }
147 })
148 }
149 Data::Union(_) => Err(Error::new(
150 ident.span(),
151 "OrionError can only be derived for enums or structs",
152 )),
153 }
154}
155
/// How `impl_error_code` treats a variant or container that has neither
/// `#[orion_error(code = ...)]` nor `#[orion_error(transparent)]`.
#[derive(Copy, Clone)]
enum MissingCode {
    /// Report a compile error (used by `#[derive(ErrorCode)]`).
    Error,
    /// Fall back to the code `500` (used by `#[derive(OrionError)]`).
    Default,
}
161
162fn impl_error_code(input: DeriveInput, missing_code: MissingCode) -> Result<TokenStream2> {
163 let ident = input.ident;
164 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
165
166 match input.data {
167 Data::Enum(data) => {
168 let arms = data
169 .variants
170 .iter()
171 .map(|variant| {
172 let args = OrionAttrs::from_attrs(&variant.attrs)?;
173 if args.transparent {
174 let (pattern, inner) = transparent_variant_pattern(variant)?;
175 Ok(quote! {
176 #pattern => ::orion_error::ErrorCode::error_code(#inner)
177 })
178 } else if let Some(code) = args.code {
179 let pattern = variant_pattern(variant);
180 Ok(quote! {
181 #pattern => #code
182 })
183 } else if matches!(missing_code, MissingCode::Default) {
184 let pattern = variant_pattern(variant);
185 Ok(quote! {
186 #pattern => 500
187 })
188 } else {
189 Err(Error::new(
190 variant.span(),
191 "missing #[orion_error(code = ...)] or #[orion_error(transparent)]",
192 ))
193 }
194 })
195 .collect::<Result<Vec<_>>>()?;
196
197 Ok(quote! {
198 impl #impl_generics ::orion_error::ErrorCode for #ident #ty_generics #where_clause {
199 fn error_code(&self) -> i32 {
200 match self {
201 #(#arms,)*
202 }
203 }
204 }
205 })
206 }
207 Data::Struct(data) => {
208 let args = OrionAttrs::from_attrs(&input.attrs)?;
209 let body = if args.transparent {
210 let inner = transparent_struct_binding(&data.fields)?;
211 let pattern = struct_pattern(&ident, &data.fields);
212 quote! {
213 match self {
214 #pattern => ::orion_error::ErrorCode::error_code(#inner),
215 }
216 }
217 } else if let Some(code) = args.code {
218 quote! { #code }
219 } else if matches!(missing_code, MissingCode::Default) {
220 quote! { 500 }
221 } else {
222 return Err(Error::new(
223 ident.span(),
224 "missing container #[orion_error(code = ...)] or #[orion_error(transparent)]",
225 ));
226 };
227
228 Ok(quote! {
229 impl #impl_generics ::orion_error::ErrorCode for #ident #ty_generics #where_clause {
230 fn error_code(&self) -> i32 {
231 #body
232 }
233 }
234 })
235 }
236 Data::Union(_) => Err(Error::new(
237 ident.span(),
238 "ErrorCode can only be derived for enums or structs",
239 )),
240 }
241}
242
243fn impl_error_identity_provider(input: DeriveInput) -> Result<TokenStream2> {
244 let ident = input.ident;
245 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
246
247 match input.data {
248 Data::Enum(data) => {
249 let stable_arms = data
250 .variants
251 .iter()
252 .map(|variant| {
253 let args = OrionAttrs::from_attrs(&variant.attrs)?;
254 if args.transparent {
255 let (pattern, inner) = transparent_variant_pattern(variant)?;
256 Ok(quote! {
257 #pattern => ::orion_error::ErrorIdentityProvider::stable_code(#inner)
258 })
259 } else if let Some(identity) = args.identity {
260 let pattern = variant_pattern(variant);
261 Ok(quote! {
262 #pattern => #identity
263 })
264 } else {
265 Err(Error::new(
266 variant.span(),
267 "missing #[orion_error(identity = ...)] or #[orion_error(transparent)]",
268 ))
269 }
270 })
271 .collect::<Result<Vec<_>>>()?;
272
273 let category_arms = data
274 .variants
275 .iter()
276 .map(|variant| {
277 let args = OrionAttrs::from_attrs(&variant.attrs)?;
278 if args.transparent {
279 let (pattern, inner) = transparent_variant_pattern(variant)?;
280 Ok(quote! {
281 #pattern => ::orion_error::ErrorIdentityProvider::error_category(#inner)
282 })
283 } else if let Some(category) = args.error_category()? {
284 let pattern = variant_pattern(variant);
285 Ok(quote! {
286 #pattern => #category
287 })
288 } else {
289 Err(Error::new(
290 variant.span(),
291 "missing #[orion_error(category = ...)] or category-prefixed string literal #[orion_error(identity = ...)]",
292 ))
293 }
294 })
295 .collect::<Result<Vec<_>>>()?;
296
297 Ok(quote! {
298 impl #impl_generics ::orion_error::ErrorIdentityProvider for #ident #ty_generics #where_clause {
299 fn stable_code(&self) -> &'static str {
300 match self {
301 #(#stable_arms,)*
302 }
303 }
304
305 fn error_category(&self) -> ::orion_error::ErrorCategory {
306 match self {
307 #(#category_arms,)*
308 }
309 }
310 }
311 })
312 }
313 Data::Struct(data) => {
314 let args = OrionAttrs::from_attrs(&input.attrs)?;
315 let (stable_body, category_body) = if args.transparent {
316 let inner = transparent_struct_binding(&data.fields)?;
317 let pattern = struct_pattern(&ident, &data.fields);
318 (
319 quote! {
320 match self {
321 #pattern => ::orion_error::ErrorIdentityProvider::stable_code(#inner),
322 }
323 },
324 quote! {
325 match self {
326 #pattern => ::orion_error::ErrorIdentityProvider::error_category(#inner),
327 }
328 },
329 )
330 } else {
331 let identity = args.identity.clone().ok_or_else(|| {
332 Error::new(
333 ident.span(),
334 "missing container #[orion_error(identity = ...)]",
335 )
336 })?;
337 let category = args.error_category()?.ok_or_else(|| {
338 Error::new(
339 ident.span(),
340 "missing container #[orion_error(category = ...)] or category-prefixed string literal #[orion_error(identity = ...)]",
341 )
342 })?;
343 (quote! { #identity }, quote! { #category })
344 };
345
346 Ok(quote! {
347 impl #impl_generics ::orion_error::ErrorIdentityProvider for #ident #ty_generics #where_clause {
348 fn stable_code(&self) -> &'static str {
349 #stable_body
350 }
351
352 fn error_category(&self) -> ::orion_error::ErrorCategory {
353 #category_body
354 }
355 }
356 })
357 }
358 Data::Union(_) => Err(Error::new(
359 ident.span(),
360 "ErrorIdentityProvider can only be derived for enums or structs",
361 )),
362 }
363}
364
365#[derive(Default)]
366struct OrionAttrs {
367 message: Option<Expr>,
368 code: Option<Expr>,
369 identity: Option<Expr>,
370 category: Option<TokenStream2>,
371 transparent: bool,
372}
373
374impl OrionAttrs {
375 fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
376 let mut out = Self::default();
377 for attr in attrs {
378 if !attr.path().is_ident("orion_error") {
379 continue;
380 }
381
382 attr.parse_nested_meta(|meta| {
383 if meta.path.is_ident("transparent") {
384 out.transparent = true;
385 return Ok(());
386 }
387
388 if meta.path.is_ident("code") {
389 out.code = Some(meta.value()?.parse()?);
390 return Ok(());
391 }
392
393 if meta.path.is_ident("message") {
394 out.message = Some(meta.value()?.parse()?);
395 return Ok(());
396 }
397
398 if meta.path.is_ident("identity") {
399 out.identity = Some(meta.value()?.parse()?);
400 return Ok(());
401 }
402
403 if meta.path.is_ident("category") {
404 let expr: Expr = meta.value()?.parse()?;
405 out.category = Some(category_expr(expr)?);
406 return Ok(());
407 }
408
409 Err(meta.error("unsupported orion_error attribute"))
410 })?;
411 }
412 Ok(out)
413 }
414
415 fn display_message(&self) -> Option<LitStr> {
416 self.message
417 .as_ref()
418 .and_then(expr_lit_str)
419 .cloned()
420 .or_else(|| {
421 self.identity
422 .as_ref()
423 .and_then(expr_lit_str)
424 .map(message_from_identity)
425 })
426 }
427
428 fn error_category(&self) -> Result<Option<TokenStream2>> {
429 if let Some(category) = self.category.clone() {
430 return Ok(Some(category));
431 }
432
433 let Some(identity) = self.identity.as_ref().and_then(expr_lit_str) else {
434 return Ok(None);
435 };
436
437 identity_category(identity).transpose()
438 }
439}
440
441fn expr_lit_str(expr: &Expr) -> Option<&LitStr> {
442 match expr {
443 Expr::Lit(ExprLit {
444 lit: Lit::Str(lit), ..
445 }) => Some(lit),
446 _ => None,
447 }
448}
449
450fn message_from_identity(identity: &LitStr) -> LitStr {
451 let message = identity
452 .value()
453 .rsplit('.')
454 .next()
455 .unwrap_or_default()
456 .replace('_', " ");
457 LitStr::new(&message, identity.span())
458}
459
460fn identity_category(identity: &LitStr) -> Option<Result<TokenStream2>> {
461 let value = identity.value();
462 let prefix = value.split('.').next().unwrap_or_default();
463 match prefix {
464 "conf" => Some(Ok(quote! { ::orion_error::ErrorCategory::Conf })),
465 "biz" => Some(Ok(quote! { ::orion_error::ErrorCategory::Biz })),
466 "logic" => Some(Ok(quote! { ::orion_error::ErrorCategory::Logic })),
467 "sys" => Some(Ok(quote! { ::orion_error::ErrorCategory::Sys })),
468 value => Some(Err(Error::new(
469 identity.span(),
470 format!(
471 "unknown identity category prefix `{value}`; expected one of: conf, biz, logic, sys"
472 ),
473 ))),
474 }
475}
476
477fn category_expr(expr: Expr) -> Result<TokenStream2> {
478 match expr {
479 Expr::Lit(ExprLit {
480 lit: Lit::Str(lit), ..
481 }) => match lit.value().as_str() {
482 "conf" => Ok(quote! { ::orion_error::ErrorCategory::Conf }),
483 "biz" => Ok(quote! { ::orion_error::ErrorCategory::Biz }),
484 "logic" => Ok(quote! { ::orion_error::ErrorCategory::Logic }),
485 "sys" => Ok(quote! { ::orion_error::ErrorCategory::Sys }),
486 value => Err(Error::new(
487 lit.span(),
488 format!("unknown error category `{value}`; expected one of: conf, biz, logic, sys"),
489 )),
490 },
491 Expr::Path(ExprPath { path, .. })
492 if path.leading_colon.is_none() && path.segments.len() == 1 =>
493 {
494 let ident = &path.segments[0].ident;
495 match ident.to_string().as_str() {
496 "Conf" => Ok(quote! { ::orion_error::ErrorCategory::Conf }),
497 "Biz" => Ok(quote! { ::orion_error::ErrorCategory::Biz }),
498 "Logic" => Ok(quote! { ::orion_error::ErrorCategory::Logic }),
499 "Sys" => Ok(quote! { ::orion_error::ErrorCategory::Sys }),
500 _ => Ok(path.to_token_stream()),
501 }
502 }
503 other => Ok(other.to_token_stream()),
504 }
505}
506
507fn variant_pattern(variant: &Variant) -> TokenStream2 {
508 let ident = &variant.ident;
509 match &variant.fields {
510 Fields::Unit => quote! { Self::#ident },
511 Fields::Unnamed(_) => quote! { Self::#ident(..) },
512 Fields::Named(_) => quote! { Self::#ident { .. } },
513 }
514}
515
516fn transparent_variant_pattern(variant: &Variant) -> Result<(TokenStream2, TokenStream2)> {
517 let ident = &variant.ident;
518 match &variant.fields {
519 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
520 Ok((quote! { Self::#ident(__inner) }, quote! { __inner }))
521 }
522 Fields::Named(fields) if fields.named.len() == 1 => {
523 let field = fields
524 .named
525 .iter()
526 .next()
527 .and_then(|field| field.ident.as_ref())
528 .unwrap();
529 Ok((quote! { Self::#ident { #field } }, quote! { #field }))
530 }
531 _ => Err(Error::new(
532 variant.span(),
533 "#[orion_error(transparent)] requires exactly one field",
534 )),
535 }
536}
537
538fn transparent_struct_binding(fields: &Fields) -> Result<TokenStream2> {
539 match fields {
540 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => Ok(quote! { __inner }),
541 Fields::Named(fields) if fields.named.len() == 1 => {
542 let field = fields
543 .named
544 .iter()
545 .next()
546 .and_then(|field| field.ident.as_ref())
547 .unwrap();
548 Ok(quote! { #field })
549 }
550 _ => Err(Error::new(
551 fields.span(),
552 "#[orion_error(transparent)] requires exactly one field",
553 )),
554 }
555}
556
557fn struct_pattern(ident: &syn::Ident, fields: &Fields) -> TokenStream2 {
558 match fields {
559 Fields::Unit => quote! { #ident },
560 Fields::Unnamed(_) => quote! { #ident(__inner) },
561 Fields::Named(fields) => {
562 let field = fields
563 .named
564 .iter()
565 .next()
566 .and_then(|field| field.ident.as_ref());
567 quote! { #ident { #field } }
568 }
569 }
570}