1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{
6 Attribute, Data, DeriveInput, Error, Fields, GenericArgument, GenericParam, Ident,
7 PathArguments, Type, parse_macro_input, parse_quote,
8};
9
10#[proc_macro_derive(Outcome, attributes(outcome))]
11pub fn derive_outcome(input: TokenStream) -> TokenStream {
12 expand(parse_macro_input!(input as DeriveInput))
13 .unwrap_or_else(Error::into_compile_error)
14 .into()
15}
16
17fn expand(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
18 let enum_ident = input.ident;
19 let Data::Enum(data) = input.data else {
20 return Err(Error::new_spanned(
21 enum_ident,
22 "Outcome can only be derived for enums",
23 ));
24 };
25
26 let mut success = None;
27 let mut short = None;
28
29 for variant in data.variants {
30 if has_flag(&variant.attrs, "success")? {
31 if success.is_some() {
32 return Err(Error::new_spanned(
33 variant.ident,
34 "multiple success variants",
35 ));
36 }
37 success = Some((variant.ident, single_field_type(&variant.fields)?));
38 } else if has_flag(&variant.attrs, "short_circuit")? {
39 if short.is_some() {
40 return Err(Error::new_spanned(
41 variant.ident,
42 "multiple short_circuit variants",
43 ));
44 }
45 short = Some((variant.ident, single_field_type(&variant.fields)?));
46 }
47 }
48
49 let (success_ident, output_ty) =
50 success.ok_or_else(|| Error::new_spanned(&enum_ident, "missing #[outcome(success)]"))?;
51 let (short_ident, residual_ty) = short
52 .ok_or_else(|| Error::new_spanned(&enum_ident, "missing #[outcome(short_circuit)]"))?;
53 let result_interop = result_interop(&input.attrs)?;
54 let residual_ident = format_ident!("{enum_ident}Residual");
55 let mut result_generics = input.generics.clone();
56 result_generics
57 .params
58 .push(GenericParam::Type(parse_quote!(E)));
59 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
60 let (result_impl_generics, _, _) = result_generics.split_for_impl();
61 let residual_generic = residual_type_generic(&residual_ty);
62
63 let result_impl = result_interop.map(|_| {
64 quote! {
65 impl #result_impl_generics ::core::ops::FromResidual<#residual_ident<#residual_ty>>
66 for ::core::result::Result<#output_ty, E>
67 where
68 E: ::core::convert::From<#residual_ty>,
69 {
70 fn from_residual(residual: #residual_ident<#residual_ty>) -> Self {
71 ::core::result::Result::Err(residual.0.into())
72 }
73 }
74 }
75 });
76
77 Ok(quote! {
78 pub struct #residual_ident<#residual_generic>(pub #residual_generic);
79
80 impl #impl_generics ::core::ops::Try for #enum_ident #ty_generics #where_clause {
81 type Output = #output_ty;
82 type Residual = #residual_ident<#residual_ty>;
83
84 fn from_output(output: Self::Output) -> Self {
85 Self::#success_ident(output)
86 }
87
88 fn branch(self) -> ::core::ops::ControlFlow<Self::Residual, Self::Output> {
89 match self {
90 Self::#success_ident(value) => ::core::ops::ControlFlow::Continue(value),
91 Self::#short_ident(value) => ::core::ops::ControlFlow::Break(#residual_ident(value)),
92 }
93 }
94 }
95
96 impl #impl_generics ::core::ops::FromResidual<#residual_ident<#residual_ty>>
97 for #enum_ident #ty_generics
98 #where_clause
99 {
100 fn from_residual(residual: #residual_ident<#residual_ty>) -> Self {
101 Self::#short_ident(residual.0)
102 }
103 }
104
105 #result_impl
106 })
107}
108
109fn has_flag(attrs: &[Attribute], name: &str) -> syn::Result<bool> {
110 let mut found = false;
111 for attr in attrs.iter().filter(|attr| attr.path().is_ident("outcome")) {
112 attr.parse_nested_meta(|meta| {
113 if meta.path.is_ident(name) {
114 found = true;
115 }
116 Ok(())
117 })?;
118 }
119 Ok(found)
120}
121
122fn result_interop(attrs: &[Attribute]) -> syn::Result<Option<Type>> {
123 let mut out = None;
124 for attr in attrs.iter().filter(|attr| attr.path().is_ident("outcome")) {
125 attr.parse_nested_meta(|meta| {
126 if meta.path.is_ident("result_interop") {
127 let value = meta.value()?;
128 out = Some(value.parse()?);
129 }
130 Ok(())
131 })?;
132 }
133 Ok(out)
134}
135
136fn single_field_type(fields: &Fields) -> syn::Result<Type> {
137 match fields {
138 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => fields
139 .unnamed
140 .first()
141 .map(|field| field.ty.clone())
142 .ok_or_else(|| {
143 Error::new_spanned(fields, "variant must have exactly one unnamed field")
144 }),
145 _ => Err(Error::new_spanned(
146 fields,
147 "variant must have exactly one unnamed field",
148 )),
149 }
150}
151
152fn residual_type_generic(ty: &Type) -> Ident {
153 if let Type::Path(path) = ty
154 && let Some(segment) = path.path.segments.last()
155 && let PathArguments::AngleBracketed(args) = &segment.arguments
156 && let Some(GenericArgument::Type(Type::Path(inner))) = args.args.first()
157 && let Some(inner_segment) = inner.path.segments.last()
158 {
159 return inner_segment.ident.clone();
160 }
161 format_ident!("R")
162}