Skip to main content

tryx_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{
6    Attribute, Data, DeriveInput, Error, Fields, GenericArgument, GenericParam, Ident,
7    PathArguments, Type, parse_macro_input, parse_quote,
8};
9
10#[proc_macro_derive(Outcome, attributes(outcome))]
11pub fn derive_outcome(input: TokenStream) -> TokenStream {
12    expand(parse_macro_input!(input as DeriveInput))
13        .unwrap_or_else(Error::into_compile_error)
14        .into()
15}
16
17fn expand(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
18    let enum_ident = input.ident;
19    let Data::Enum(data) = input.data else {
20        return Err(Error::new_spanned(
21            enum_ident,
22            "Outcome can only be derived for enums",
23        ));
24    };
25
26    let mut success = None;
27    let mut short = None;
28
29    for variant in data.variants {
30        if has_flag(&variant.attrs, "success")? {
31            if success.is_some() {
32                return Err(Error::new_spanned(
33                    variant.ident,
34                    "multiple success variants",
35                ));
36            }
37            success = Some((variant.ident, single_field_type(&variant.fields)?));
38        } else if has_flag(&variant.attrs, "short_circuit")? {
39            if short.is_some() {
40                return Err(Error::new_spanned(
41                    variant.ident,
42                    "multiple short_circuit variants",
43                ));
44            }
45            short = Some((variant.ident, single_field_type(&variant.fields)?));
46        }
47    }
48
49    let (success_ident, output_ty) =
50        success.ok_or_else(|| Error::new_spanned(&enum_ident, "missing #[outcome(success)]"))?;
51    let (short_ident, residual_ty) = short
52        .ok_or_else(|| Error::new_spanned(&enum_ident, "missing #[outcome(short_circuit)]"))?;
53    let result_interop = result_interop(&input.attrs)?;
54    let residual_ident = format_ident!("{enum_ident}Residual");
55    let mut result_generics = input.generics.clone();
56    result_generics
57        .params
58        .push(GenericParam::Type(parse_quote!(E)));
59    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
60    let (result_impl_generics, _, _) = result_generics.split_for_impl();
61    let residual_generic = residual_type_generic(&residual_ty);
62
63    let result_impl = result_interop.map(|_| {
64        quote! {
65            impl #result_impl_generics ::core::ops::FromResidual<#residual_ident<#residual_ty>>
66                for ::core::result::Result<#output_ty, E>
67            where
68                E: ::core::convert::From<#residual_ty>,
69            {
70                fn from_residual(residual: #residual_ident<#residual_ty>) -> Self {
71                    ::core::result::Result::Err(residual.0.into())
72                }
73            }
74        }
75    });
76
77    Ok(quote! {
78        pub struct #residual_ident<#residual_generic>(pub #residual_generic);
79
80        impl #impl_generics ::core::ops::Try for #enum_ident #ty_generics #where_clause {
81            type Output = #output_ty;
82            type Residual = #residual_ident<#residual_ty>;
83
84            fn from_output(output: Self::Output) -> Self {
85                Self::#success_ident(output)
86            }
87
88            fn branch(self) -> ::core::ops::ControlFlow<Self::Residual, Self::Output> {
89                match self {
90                    Self::#success_ident(value) => ::core::ops::ControlFlow::Continue(value),
91                    Self::#short_ident(value) => ::core::ops::ControlFlow::Break(#residual_ident(value)),
92                }
93            }
94        }
95
96        impl #impl_generics ::core::ops::FromResidual<#residual_ident<#residual_ty>>
97            for #enum_ident #ty_generics
98            #where_clause
99        {
100            fn from_residual(residual: #residual_ident<#residual_ty>) -> Self {
101                Self::#short_ident(residual.0)
102            }
103        }
104
105        #result_impl
106    })
107}
108
109fn has_flag(attrs: &[Attribute], name: &str) -> syn::Result<bool> {
110    let mut found = false;
111    for attr in attrs.iter().filter(|attr| attr.path().is_ident("outcome")) {
112        attr.parse_nested_meta(|meta| {
113            if meta.path.is_ident(name) {
114                found = true;
115            }
116            Ok(())
117        })?;
118    }
119    Ok(found)
120}
121
122fn result_interop(attrs: &[Attribute]) -> syn::Result<Option<Type>> {
123    let mut out = None;
124    for attr in attrs.iter().filter(|attr| attr.path().is_ident("outcome")) {
125        attr.parse_nested_meta(|meta| {
126            if meta.path.is_ident("result_interop") {
127                let value = meta.value()?;
128                out = Some(value.parse()?);
129            }
130            Ok(())
131        })?;
132    }
133    Ok(out)
134}
135
136fn single_field_type(fields: &Fields) -> syn::Result<Type> {
137    match fields {
138        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => fields
139            .unnamed
140            .first()
141            .map(|field| field.ty.clone())
142            .ok_or_else(|| {
143                Error::new_spanned(fields, "variant must have exactly one unnamed field")
144            }),
145        _ => Err(Error::new_spanned(
146            fields,
147            "variant must have exactly one unnamed field",
148        )),
149    }
150}
151
152fn residual_type_generic(ty: &Type) -> Ident {
153    if let Type::Path(path) = ty
154        && let Some(segment) = path.path.segments.last()
155        && let PathArguments::AngleBracketed(args) = &segment.arguments
156        && let Some(GenericArgument::Type(Type::Path(inner))) = args.args.first()
157        && let Some(inner_segment) = inner.path.segments.last()
158    {
159        return inner_segment.ident.clone();
160    }
161    format_ident!("R")
162}