1use heck::ToSnakeCase;
2use proc_macro::TokenStream;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::quote;
5use syn::{Data, DeriveInput, Field, Fields, Ident};
6
7#[proc_macro_derive(ErrorStack, attributes(source, stack_source, location))]
157pub fn derive_error_stack(input: TokenStream) -> TokenStream {
158 let input = syn::parse_macro_input!(input as DeriveInput);
159 match derive_impl(input) {
160 Ok(tokens) => tokens.into(),
161 Err(err) => err.to_compile_error().into(),
162 }
163}
164
165fn derive_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
166 let name = &input.ident;
167
168 match &input.data {
169 Data::Enum(data) => {
170 let mut constructor_methods = Vec::new();
171 let mut location_arms = Vec::new();
172 let mut stack_source_arms = Vec::new();
173
174 for variant in &data.variants {
175 let variant_name = &variant.ident;
176 let fields = match &variant.fields {
177 Fields::Named(f) => f,
178 Fields::Unnamed(_) => {
179 return Err(syn::Error::new(
180 variant_name.span(),
181 format!(
182 "ErrorStack derive requires named (struct) variants; \
183 found tuple variant `{variant_name}`"
184 ),
185 ));
186 }
187 Fields::Unit => {
188 return Err(syn::Error::new(
189 variant_name.span(),
190 format!(
191 "ErrorStack derive requires named (struct) variants; \
192 found unit variant `{variant_name}`"
193 ),
194 ));
195 }
196 };
197
198 let parsed = parse_fields(&fields.named, variant_name)?;
199
200 constructor_methods.push(gen_constructor_enum(variant_name, &parsed));
201 location_arms.push(gen_location_arm_enum(variant_name, &parsed));
202 stack_source_arms.push(gen_stack_source_arm_enum(variant_name, &parsed));
203 }
204
205 Ok(quote! {
206 impl #name {
207 #(#constructor_methods)*
208 }
209
210 impl ::errorstack::ErrorStack for #name {
211 fn location(&self) -> Option<&'static ::std::panic::Location<'static>> {
212 match self {
213 #(#location_arms)*
214 }
215 }
216
217 fn stack_source(&self) -> Option<&dyn ::errorstack::ErrorStack> {
218 match self {
219 #(#stack_source_arms)*
220 }
221 }
222 }
223 })
224 }
225
226 Data::Struct(data) => {
227 let fields = match &data.fields {
228 Fields::Named(f) => f,
229 _ => {
230 return Err(syn::Error::new(
231 name.span(),
232 "ErrorStack derive requires named fields",
233 ));
234 }
235 };
236
237 let parsed = parse_fields(&fields.named, name)?;
238 let constructor = gen_constructor_struct(name, &parsed);
239
240 let location_body = if let Some(loc) = &parsed.location {
241 let loc_ident = &loc.ident;
242 quote! { Some(self.#loc_ident) }
243 } else {
244 quote! { None }
245 };
246
247 let stack_source_body = if parsed.stack_source {
248 let src = parsed.source.as_ref().unwrap();
249 let src_ident = &src.ident;
250 if parsed.optional_source {
251 quote! { self.#src_ident.as_ref().map(|s| s as &dyn ::errorstack::ErrorStack) }
252 } else {
253 quote! { Some(&self.#src_ident as &dyn ::errorstack::ErrorStack) }
254 }
255 } else {
256 quote! { None }
257 };
258
259 Ok(quote! {
260 impl #name {
261 #constructor
262 }
263
264 impl ::errorstack::ErrorStack for #name {
265 fn location(&self) -> Option<&'static ::std::panic::Location<'static>> {
266 #location_body
267 }
268
269 fn stack_source(&self) -> Option<&dyn ::errorstack::ErrorStack> {
270 #stack_source_body
271 }
272 }
273 })
274 }
275
276 Data::Union(_) => Err(syn::Error::new(
277 name.span(),
278 "ErrorStack derive is not supported on unions",
279 )),
280 }
281}
282
/// Classification of one struct's or variant's named fields, as produced by
/// `parse_fields`.
struct ParsedFields<'a> {
    /// The single source field, if any: a field named `source`, or one marked
    /// `#[source]` or `#[stack_source]`.
    source: Option<&'a Field>,
    /// The field marked `#[location]`, if any.
    location: Option<&'a Field>,
    /// `true` when the source field carries `#[stack_source]`.
    stack_source: bool,
    /// `true` when the source field's type was recognised as `Option<T>`.
    optional_source: bool,
    /// The `T` of an `Option<T>` source; `Some` exactly when `optional_source` is set.
    inner_source_ty: Option<syn::Type>,
    /// Every remaining field, in declaration order; these become constructor parameters.
    user_fields: Vec<&'a Field>,
}
292
293fn attr(field: &Field, name: &str) -> bool {
294 field.attrs.iter().any(|a| a.path().is_ident(name))
295}
296
297fn extract_option_inner(ty: &syn::Type) -> Option<&syn::Type> {
299 let syn::Type::Path(type_path) = ty else {
300 return None;
301 };
302 let segment = type_path.path.segments.last()?;
303 if segment.ident != "Option" {
304 return None;
305 }
306 let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
307 return None;
308 };
309 if args.args.len() != 1 {
310 return None;
311 }
312 let syn::GenericArgument::Type(inner) = args.args.first()? else {
313 return None;
314 };
315 Some(inner)
316}
317
318fn parse_fields<'a>(
319 fields: &'a syn::punctuated::Punctuated<Field, syn::Token![,]>,
320 context_name: &Ident,
321) -> syn::Result<ParsedFields<'a>> {
322 let mut source: Option<&Field> = None;
323 let mut location: Option<&Field> = None;
324 let mut stack_source = false;
325 let mut optional_source = false;
326 let mut inner_source_ty = None;
327 let mut user_fields = Vec::new();
328
329 for field in fields {
330 let ident = field.ident.as_ref().unwrap();
331 let source_by_name = ident == "source";
332 let source_by_attr = attr(field, "source");
333 let location_attr = attr(field, "location");
334 let stack_source_attr = attr(field, "stack_source");
335
336 if source_by_name || source_by_attr || stack_source_attr {
337 if source.is_some() {
338 return Err(syn::Error::new(
339 ident.span(),
340 format!("variant `{context_name}` has multiple source fields"),
341 ));
342 }
343 source = Some(field);
344 if stack_source_attr {
345 stack_source = true;
346 }
347 if let Some(inner) = extract_option_inner(&field.ty) {
348 optional_source = true;
349 inner_source_ty = Some(inner.clone());
350 }
351 } else if location_attr {
352 if location.is_some() {
353 return Err(syn::Error::new(
354 ident.span(),
355 format!("variant `{context_name}` has multiple location fields"),
356 ));
357 }
358 location = Some(field);
359 } else {
360 user_fields.push(field);
361 }
362 }
363
364 Ok(ParsedFields {
365 source,
366 location,
367 stack_source,
368 optional_source,
369 inner_source_ty,
370 user_fields,
371 })
372}
373
/// Naming and documentation inputs for `gen_constructor`, which differ between
/// enum variants and structs.
struct ConstructorCtx {
    /// Name of the primary constructor method.
    method_name: Ident,
    /// Name of the extra constructor that attaches a source; only emitted when
    /// the source field is optional.
    with_method_name: Ident,
    /// Doc string attached to the primary constructor.
    doc: String,
    /// Doc string attached to the `with` constructor.
    doc_with: String,
    /// Path used to build the value: `Self` for structs, `Self::Variant` for enums.
    self_path: TokenStream2,
}
383
384fn gen_constructor(ctx: &ConstructorCtx, parsed: &ParsedFields<'_>) -> TokenStream2 {
385 let ConstructorCtx {
386 method_name,
387 with_method_name,
388 doc,
389 doc_with,
390 self_path,
391 } = ctx;
392
393 let user_params: Vec<_> = parsed
394 .user_fields
395 .iter()
396 .map(|f| {
397 let ident = &f.ident;
398 let ty = &f.ty;
399 quote! { #ident: #ty }
400 })
401 .collect();
402
403 let user_field_names: Vec<_> = parsed.user_fields.iter().map(|f| &f.ident).collect();
404
405 let location_init = parsed.location.as_ref().map(|f| {
406 let ident = &f.ident;
407 quote! { #ident: location, }
408 });
409
410 let location_capture = parsed.location.as_ref().map(|_| {
411 quote! { let location = ::std::panic::Location::caller(); }
412 });
413
414 if let Some(src) = &parsed.source {
415 let src_ident = &src.ident;
416
417 if parsed.optional_source {
418 let inner_ty = parsed.inner_source_ty.as_ref().unwrap();
419 quote! {
420 #[doc = #doc]
421 #[track_caller]
422 pub(crate) fn #method_name(#(#user_params),*) -> Self {
423 #location_capture
424 #self_path {
425 #src_ident: None,
426 #(#user_field_names,)*
427 #location_init
428 }
429 }
430
431 #[doc = #doc_with]
432 #[track_caller]
433 pub(crate) fn #with_method_name(#(#user_params),*) -> impl ::std::ops::FnOnce(#inner_ty) -> Self {
434 #location_capture
435 move |#src_ident| #self_path {
436 #src_ident: Some(#src_ident),
437 #(#user_field_names,)*
438 #location_init
439 }
440 }
441 }
442 } else if parsed.user_fields.is_empty() {
443 let src_ty = &src.ty;
444 quote! {
445 #[doc = #doc]
446 #[track_caller]
447 pub(crate) fn #method_name(#src_ident: #src_ty) -> Self {
448 #location_capture
449 #self_path {
450 #src_ident,
451 #location_init
452 }
453 }
454 }
455 } else {
456 let src_ty = &src.ty;
457 quote! {
458 #[doc = #doc]
459 #[track_caller]
460 pub(crate) fn #method_name(#(#user_params),*) -> impl ::std::ops::FnOnce(#src_ty) -> Self {
461 #location_capture
462 move |#src_ident| #self_path {
463 #src_ident,
464 #(#user_field_names,)*
465 #location_init
466 }
467 }
468 }
469 }
470 } else {
471 quote! {
472 #[doc = #doc]
473 #[track_caller]
474 pub(crate) fn #method_name(#(#user_params),*) -> Self {
475 #location_capture
476 #self_path {
477 #(#user_field_names,)*
478 #location_init
479 }
480 }
481 }
482 }
483}
484
485fn gen_constructor_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
486 let snake = variant_name.to_string().to_snake_case();
487 let ctx = ConstructorCtx {
488 method_name: Ident::new(&snake, variant_name.span()),
489 with_method_name: Ident::new(&format!("{snake}_with"), variant_name.span()),
490 doc: format!("Constructs a [`{variant_name}`](Self::{variant_name}) error."),
491 doc_with: format!(
492 "Constructs a [`{variant_name}`](Self::{variant_name}) error with a source."
493 ),
494 self_path: quote! { Self::#variant_name },
495 };
496 gen_constructor(&ctx, parsed)
497}
498
499fn gen_constructor_struct(type_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
500 let ctx = ConstructorCtx {
501 method_name: Ident::new("new", type_name.span()),
502 with_method_name: Ident::new("new_with", type_name.span()),
503 doc: format!("Constructs a [`{type_name}`]."),
504 doc_with: format!("Constructs a [`{type_name}`] with a source."),
505 self_path: quote! { Self },
506 };
507 gen_constructor(&ctx, parsed)
508}
509
510fn gen_location_arm_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
511 if let Some(loc) = &parsed.location {
512 let loc_ident = &loc.ident;
513 quote! {
514 Self::#variant_name { #loc_ident, .. } => Some(#loc_ident),
515 }
516 } else {
517 quote! {
518 Self::#variant_name { .. } => None,
519 }
520 }
521}
522
523fn gen_stack_source_arm_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
524 if parsed.stack_source {
525 let src_ident = &parsed.source.unwrap().ident;
526 if parsed.optional_source {
527 quote! {
528 Self::#variant_name { #src_ident, .. } => #src_ident.as_ref().map(|s| s as &dyn ::errorstack::ErrorStack),
529 }
530 } else {
531 quote! {
532 Self::#variant_name { #src_ident, .. } => Some(#src_ident as &dyn ::errorstack::ErrorStack),
533 }
534 }
535 } else {
536 quote! {
537 Self::#variant_name { .. } => None,
538 }
539 }
540}