errore_impl/
valid.rs

1use std::collections::BTreeSet as Set;
2
3use quote::ToTokens;
4use syn::{Error, GenericArgument, Member, PathArguments, Result, Type};
5
6use crate::ast::{DeriveType, Enum, Field, Input, Struct, Variant};
7use crate::attr::Attrs;
8
9impl Input<'_> {
10    pub(crate) fn validate(&self) -> Result<()> {
11        match self {
12            Input::Struct(input) => input.validate(),
13            Input::Enum(input) => input.validate(),
14        }
15    }
16}
17
18impl Struct<'_> {
19    fn validate(&self) -> Result<()> {
20        check_non_field_attrs(&self.attrs)?;
21        if let Some(transparent) = self.attrs.transparent {
22            if self.fields.len() != 1 {
23                return Err(Error::new_spanned(
24                    transparent.original,
25                    "#[error(transparent)] requires exactly one field",
26                ));
27            }
28            if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
29                return Err(Error::new_spanned(
30                    source,
31                    "transparent error struct can't contain #[source]",
32                ));
33            }
34        }
35        check_field_attrs(&self.fields)?;
36        for field in &self.fields {
37            field.validate()?;
38        }
39        Ok(())
40    }
41}
42
43impl Enum<'_> {
44    fn validate_as_error(&self) -> Result<()> {
45        let has_display = self.has_display();
46
47        for variant in &self.variants {
48            variant.validate()?;
49            if has_display && variant.attrs.display.is_none() && variant.attrs.transparent.is_none()
50            {
51                return Err(Error::new_spanned(
52                    variant.original,
53                    "missing #[error(\"...\")] display attribute",
54                ));
55            }
56        }
57
58        let mut from_types = Set::new();
59        for variant in &self.variants {
60            if let Some(from_field) = variant.from_field() {
61                let repr = from_field.ty.to_token_stream().to_string();
62                if !from_types.insert(repr) {
63                    return Err(Error::new_spanned(
64                        from_field.original,
65                        "cannot derive From because another variant has the same source type",
66                    ));
67                }
68            }
69        }
70
71        Ok(())
72    }
73
74    fn validate_as_display(&self) -> Result<()> {
75        // at the moment no validation is required, since if no #[display(\"...\")] attribute is supplied,
76        // it will fallback to the default display implementation
77        Ok(())
78    }
79
80    fn validate(&self) -> Result<()> {
81        check_non_field_attrs(&self.attrs)?;
82        match self.derive {
83            DeriveType::Error => self.validate_as_error(),
84            DeriveType::Display => self.validate_as_display(),
85        }
86    }
87}
88
89impl Variant<'_> {
90    fn validate(&self) -> Result<()> {
91        check_non_field_attrs(&self.attrs)?;
92        if self.attrs.transparent.is_some() {
93            if self.fields.len() != 1 {
94                return Err(Error::new_spanned(
95                    self.original,
96                    "#[error(transparent)] requires exactly one field",
97                ));
98            }
99            if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
100                return Err(Error::new_spanned(
101                    source,
102                    "transparent variant can't contain #[source]",
103                ));
104            }
105        }
106        check_field_attrs(&self.fields)?;
107        for field in &self.fields {
108            field.validate()?;
109        }
110        Ok(())
111    }
112}
113
114impl Field<'_> {
115    fn validate(&self) -> Result<()> {
116        if let Some(display) = &self.attrs.display {
117            return Err(Error::new_spanned(
118                display.original,
119                "not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant",
120            ));
121        }
122        Ok(())
123    }
124}
125
126fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
127    if let Some(from) = &attrs.from {
128        return Err(Error::new_spanned(
129            from,
130            "not expected here; the #[from] attribute belongs on a specific field",
131        ));
132    }
133    if let Some(source) = &attrs.source {
134        return Err(Error::new_spanned(
135            source,
136            "not expected here; the #[source] attribute belongs on a specific field",
137        ));
138    }
139    if let Some(display) = &attrs.display {
140        if attrs.transparent.is_some() {
141            return Err(Error::new_spanned(
142                display.original,
143                "cannot have both #[error(transparent)] and a display attribute",
144            ));
145        }
146    }
147    Ok(())
148}
149
150fn check_field_attrs(fields: &[Field]) -> Result<()> {
151    let mut from_field = None;
152    let mut source_field = None;
153    for field in fields {
154        if let Some(from) = field.attrs.from {
155            if from_field.is_some() {
156                return Err(Error::new_spanned(from, "duplicate #[from] attribute"));
157            }
158            from_field = Some(field);
159        }
160        if let Some(source) = field.attrs.source {
161            if source_field.is_some() {
162                return Err(Error::new_spanned(source, "duplicate #[source] attribute"));
163            }
164            source_field = Some(field);
165        }
166        if let Some(transparent) = field.attrs.transparent {
167            return Err(Error::new_spanned(
168                transparent.original,
169                "#[error(transparent)] needs to go outside the enum or struct, not on an individual field",
170            ));
171        }
172    }
173    if let (Some(from_field), Some(source_field)) = (from_field, source_field) {
174        if !same_member(from_field, source_field) {
175            return Err(Error::new_spanned(
176                from_field.attrs.from,
177                "#[from] is only supported on the source field, not any other field",
178            ));
179        }
180    }
181    if let Some(from_field) = from_field {
182        let max_expected_fields = 1;
183        if fields.len() > max_expected_fields {
184            return Err(Error::new_spanned(
185                from_field.attrs.from,
186                "deriving From requires no fields other than source",
187            ));
188        }
189    }
190    if let Some(source_field) = source_field.or(from_field) {
191        if contains_non_static_lifetime(source_field.ty) {
192            return Err(Error::new_spanned(
193                &source_field.original.ty,
194                "non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static",
195            ));
196        }
197    }
198    Ok(())
199}
200
201fn same_member(one: &Field, two: &Field) -> bool {
202    match (&one.member, &two.member) {
203        (Member::Named(one), Member::Named(two)) => one == two,
204        (Member::Unnamed(one), Member::Unnamed(two)) => one.index == two.index,
205        _ => unreachable!(),
206    }
207}
208
209fn contains_non_static_lifetime(ty: &Type) -> bool {
210    match ty {
211        Type::Path(ty) => {
212            let bracketed = match &ty.path.segments.last().unwrap().arguments {
213                PathArguments::AngleBracketed(bracketed) => bracketed,
214                _ => return false,
215            };
216            for arg in &bracketed.args {
217                match arg {
218                    GenericArgument::Type(ty) if contains_non_static_lifetime(ty) => return true,
219                    GenericArgument::Lifetime(lifetime) if lifetime.ident != "static" => {
220                        return true
221                    }
222                    _ => {}
223                }
224            }
225            false
226        }
227        Type::Reference(ty) => ty
228            .lifetime
229            .as_ref()
230            .map_or(false, |lifetime| lifetime.ident != "static"),
231        _ => false, // maybe implement later if there are common other cases
232    }
233}