1use heck::ToSnakeCase;
2use proc_macro::TokenStream;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::quote;
5use syn::{Data, DeriveInput, Field, Fields, Ident};
6
7#[proc_macro_derive(ErrorStack, attributes(source, stack_source, location))]
155pub fn derive_error_stack(input: TokenStream) -> TokenStream {
156 let input = syn::parse_macro_input!(input as DeriveInput);
157 match derive_impl(input) {
158 Ok(tokens) => tokens.into(),
159 Err(err) => err.to_compile_error().into(),
160 }
161}
162
163fn derive_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
164 let name = &input.ident;
165
166 match &input.data {
167 Data::Enum(data) => {
168 let mut constructor_methods = Vec::new();
169 let mut location_arms = Vec::new();
170 let mut stack_source_arms = Vec::new();
171
172 for variant in &data.variants {
173 let variant_name = &variant.ident;
174 let fields = match &variant.fields {
175 Fields::Named(f) => f,
176 Fields::Unnamed(_) => {
177 return Err(syn::Error::new(
178 variant_name.span(),
179 format!(
180 "ErrorStack derive requires named (struct) variants; \
181 found tuple variant `{variant_name}`"
182 ),
183 ));
184 }
185 Fields::Unit => {
186 return Err(syn::Error::new(
187 variant_name.span(),
188 format!(
189 "ErrorStack derive requires named (struct) variants; \
190 found unit variant `{variant_name}`"
191 ),
192 ));
193 }
194 };
195
196 let parsed = parse_fields(&fields.named, variant_name)?;
197
198 constructor_methods.push(gen_constructor_enum(variant_name, &parsed));
199 location_arms.push(gen_location_arm_enum(variant_name, &parsed));
200 stack_source_arms.push(gen_stack_source_arm_enum(variant_name, &parsed));
201 }
202
203 Ok(quote! {
204 impl #name {
205 #(#constructor_methods)*
206 }
207
208 impl ::errorstack::ErrorStack for #name {
209 fn location(&self) -> Option<&'static ::std::panic::Location<'static>> {
210 match self {
211 #(#location_arms)*
212 }
213 }
214
215 fn stack_source(&self) -> Option<&dyn ::errorstack::ErrorStack> {
216 match self {
217 #(#stack_source_arms)*
218 }
219 }
220 }
221 })
222 }
223
224 Data::Struct(data) => {
225 let fields = match &data.fields {
226 Fields::Named(f) => f,
227 _ => {
228 return Err(syn::Error::new(
229 name.span(),
230 "ErrorStack derive requires named fields",
231 ));
232 }
233 };
234
235 let parsed = parse_fields(&fields.named, name)?;
236 let constructor = gen_constructor_struct(name, &parsed);
237
238 let location_body = if let Some(loc) = &parsed.location {
239 let loc_ident = &loc.ident;
240 quote! { Some(self.#loc_ident) }
241 } else {
242 quote! { None }
243 };
244
245 let stack_source_body = if parsed.stack_source {
246 let src = parsed.source.as_ref().unwrap();
247 let src_ident = &src.ident;
248 if parsed.optional_source {
249 quote! { self.#src_ident.as_ref().map(|s| s as &dyn ::errorstack::ErrorStack) }
250 } else {
251 quote! { Some(&self.#src_ident as &dyn ::errorstack::ErrorStack) }
252 }
253 } else {
254 quote! { None }
255 };
256
257 Ok(quote! {
258 impl #name {
259 #constructor
260 }
261
262 impl ::errorstack::ErrorStack for #name {
263 fn location(&self) -> Option<&'static ::std::panic::Location<'static>> {
264 #location_body
265 }
266
267 fn stack_source(&self) -> Option<&dyn ::errorstack::ErrorStack> {
268 #stack_source_body
269 }
270 }
271 })
272 }
273
274 Data::Union(_) => Err(syn::Error::new(
275 name.span(),
276 "ErrorStack derive is not supported on unions",
277 )),
278 }
279}
280
281struct ParsedFields<'a> {
282 source: Option<&'a Field>,
283 location: Option<&'a Field>,
284 stack_source: bool,
285 optional_source: bool,
286 inner_source_ty: Option<syn::Type>,
288 user_fields: Vec<&'a Field>,
289}
290
291fn attr(field: &Field, name: &str) -> bool {
292 field.attrs.iter().any(|a| a.path().is_ident(name))
293}
294
295fn extract_option_inner(ty: &syn::Type) -> Option<&syn::Type> {
297 let syn::Type::Path(type_path) = ty else {
298 return None;
299 };
300 let segment = type_path.path.segments.last()?;
301 if segment.ident != "Option" {
302 return None;
303 }
304 let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
305 return None;
306 };
307 if args.args.len() != 1 {
308 return None;
309 }
310 let syn::GenericArgument::Type(inner) = args.args.first()? else {
311 return None;
312 };
313 Some(inner)
314}
315
316fn parse_fields<'a>(
317 fields: &'a syn::punctuated::Punctuated<Field, syn::Token![,]>,
318 context_name: &Ident,
319) -> syn::Result<ParsedFields<'a>> {
320 let mut source: Option<&Field> = None;
321 let mut location: Option<&Field> = None;
322 let mut stack_source = false;
323 let mut optional_source = false;
324 let mut inner_source_ty = None;
325 let mut user_fields = Vec::new();
326
327 for field in fields {
328 let ident = field.ident.as_ref().unwrap();
329 let source_by_name = ident == "source";
330 let source_by_attr = attr(field, "source");
331 let location_attr = attr(field, "location");
332 let stack_source_attr = attr(field, "stack_source");
333
334 if source_by_name || source_by_attr || stack_source_attr {
335 if source.is_some() {
336 return Err(syn::Error::new(
337 ident.span(),
338 format!("variant `{context_name}` has multiple source fields"),
339 ));
340 }
341 source = Some(field);
342 if stack_source_attr {
343 stack_source = true;
344 }
345 if let Some(inner) = extract_option_inner(&field.ty) {
346 optional_source = true;
347 inner_source_ty = Some(inner.clone());
348 }
349 } else if location_attr {
350 if location.is_some() {
351 return Err(syn::Error::new(
352 ident.span(),
353 format!("variant `{context_name}` has multiple location fields"),
354 ));
355 }
356 location = Some(field);
357 } else {
358 user_fields.push(field);
359 }
360 }
361
362 Ok(ParsedFields {
363 source,
364 location,
365 stack_source,
366 optional_source,
367 inner_source_ty,
368 user_fields,
369 })
370}
371
372struct ConstructorCtx {
374 method_name: Ident,
375 with_method_name: Ident,
376 doc: String,
377 doc_with: String,
378 self_path: TokenStream2,
380}
381
382fn gen_constructor(ctx: &ConstructorCtx, parsed: &ParsedFields<'_>) -> TokenStream2 {
383 let ConstructorCtx {
384 method_name,
385 with_method_name,
386 doc,
387 doc_with,
388 self_path,
389 } = ctx;
390
391 let user_params: Vec<_> = parsed
392 .user_fields
393 .iter()
394 .map(|f| {
395 let ident = &f.ident;
396 let ty = &f.ty;
397 quote! { #ident: #ty }
398 })
399 .collect();
400
401 let user_field_names: Vec<_> = parsed.user_fields.iter().map(|f| &f.ident).collect();
402
403 let location_init = parsed.location.as_ref().map(|f| {
404 let ident = &f.ident;
405 quote! { #ident: location, }
406 });
407
408 let location_capture = parsed.location.as_ref().map(|_| {
409 quote! { let location = ::std::panic::Location::caller(); }
410 });
411
412 if let Some(src) = &parsed.source {
413 let src_ident = &src.ident;
414
415 if parsed.optional_source {
416 let inner_ty = parsed.inner_source_ty.as_ref().unwrap();
417 quote! {
418 #[doc = #doc]
419 #[track_caller]
420 pub(crate) fn #method_name(#(#user_params),*) -> Self {
421 #location_capture
422 #self_path {
423 #src_ident: None,
424 #(#user_field_names,)*
425 #location_init
426 }
427 }
428
429 #[doc = #doc_with]
430 #[track_caller]
431 pub(crate) fn #with_method_name(#(#user_params),*) -> impl ::std::ops::FnOnce(#inner_ty) -> Self {
432 #location_capture
433 move |#src_ident| #self_path {
434 #src_ident: Some(#src_ident),
435 #(#user_field_names,)*
436 #location_init
437 }
438 }
439 }
440 } else {
441 let src_ty = &src.ty;
442 quote! {
443 #[doc = #doc]
444 #[track_caller]
445 pub(crate) fn #method_name(#(#user_params),*) -> impl ::std::ops::FnOnce(#src_ty) -> Self {
446 #location_capture
447 move |#src_ident| #self_path {
448 #src_ident,
449 #(#user_field_names,)*
450 #location_init
451 }
452 }
453 }
454 }
455 } else {
456 quote! {
457 #[doc = #doc]
458 #[track_caller]
459 pub(crate) fn #method_name(#(#user_params),*) -> Self {
460 #location_capture
461 #self_path {
462 #(#user_field_names,)*
463 #location_init
464 }
465 }
466 }
467 }
468}
469
470fn gen_constructor_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
471 let snake = variant_name.to_string().to_snake_case();
472 let ctx = ConstructorCtx {
473 method_name: Ident::new(&snake, variant_name.span()),
474 with_method_name: Ident::new(&format!("{snake}_with"), variant_name.span()),
475 doc: format!("Constructs a [`{variant_name}`](Self::{variant_name}) error."),
476 doc_with: format!(
477 "Constructs a [`{variant_name}`](Self::{variant_name}) error with a source."
478 ),
479 self_path: quote! { Self::#variant_name },
480 };
481 gen_constructor(&ctx, parsed)
482}
483
484fn gen_constructor_struct(type_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
485 let ctx = ConstructorCtx {
486 method_name: Ident::new("new", type_name.span()),
487 with_method_name: Ident::new("new_with", type_name.span()),
488 doc: format!("Constructs a [`{type_name}`]."),
489 doc_with: format!("Constructs a [`{type_name}`] with a source."),
490 self_path: quote! { Self },
491 };
492 gen_constructor(&ctx, parsed)
493}
494
495fn gen_location_arm_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
496 if let Some(loc) = &parsed.location {
497 let loc_ident = &loc.ident;
498 quote! {
499 Self::#variant_name { #loc_ident, .. } => Some(#loc_ident),
500 }
501 } else {
502 quote! {
503 Self::#variant_name { .. } => None,
504 }
505 }
506}
507
508fn gen_stack_source_arm_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
509 if parsed.stack_source {
510 let src_ident = &parsed.source.unwrap().ident;
511 if parsed.optional_source {
512 quote! {
513 Self::#variant_name { #src_ident, .. } => #src_ident.as_ref().map(|s| s as &dyn ::errorstack::ErrorStack),
514 }
515 } else {
516 quote! {
517 Self::#variant_name { #src_ident, .. } => Some(#src_ident as &dyn ::errorstack::ErrorStack),
518 }
519 }
520 } else {
521 quote! {
522 Self::#variant_name { .. } => None,
523 }
524 }
525}