1use std::collections::BTreeSet as Set;
2
3use quote::ToTokens;
4use syn::{Error, GenericArgument, Member, PathArguments, Result, Type};
5
6use crate::ast::{DeriveType, Enum, Field, Input, Struct, Variant};
7use crate::attr::Attrs;
8
9impl Input<'_> {
10 pub(crate) fn validate(&self) -> Result<()> {
11 match self {
12 Input::Struct(input) => input.validate(),
13 Input::Enum(input) => input.validate(),
14 }
15 }
16}
17
18impl Struct<'_> {
19 fn validate(&self) -> Result<()> {
20 check_non_field_attrs(&self.attrs)?;
21 if let Some(transparent) = self.attrs.transparent {
22 if self.fields.len() != 1 {
23 return Err(Error::new_spanned(
24 transparent.original,
25 "#[error(transparent)] requires exactly one field",
26 ));
27 }
28 if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
29 return Err(Error::new_spanned(
30 source,
31 "transparent error struct can't contain #[source]",
32 ));
33 }
34 }
35 check_field_attrs(&self.fields)?;
36 for field in &self.fields {
37 field.validate()?;
38 }
39 Ok(())
40 }
41}
42
43impl Enum<'_> {
44 fn validate_as_error(&self) -> Result<()> {
45 let has_display = self.has_display();
46
47 for variant in &self.variants {
48 variant.validate()?;
49 if has_display && variant.attrs.display.is_none() && variant.attrs.transparent.is_none()
50 {
51 return Err(Error::new_spanned(
52 variant.original,
53 "missing #[error(\"...\")] display attribute",
54 ));
55 }
56 }
57
58 let mut from_types = Set::new();
59 for variant in &self.variants {
60 if let Some(from_field) = variant.from_field() {
61 let repr = from_field.ty.to_token_stream().to_string();
62 if !from_types.insert(repr) {
63 return Err(Error::new_spanned(
64 from_field.original,
65 "cannot derive From because another variant has the same source type",
66 ));
67 }
68 }
69 }
70
71 Ok(())
72 }
73
74 fn validate_as_display(&self) -> Result<()> {
75 Ok(())
78 }
79
80 fn validate(&self) -> Result<()> {
81 check_non_field_attrs(&self.attrs)?;
82 match self.derive {
83 DeriveType::Error => self.validate_as_error(),
84 DeriveType::Display => self.validate_as_display(),
85 }
86 }
87}
88
89impl Variant<'_> {
90 fn validate(&self) -> Result<()> {
91 check_non_field_attrs(&self.attrs)?;
92 if self.attrs.transparent.is_some() {
93 if self.fields.len() != 1 {
94 return Err(Error::new_spanned(
95 self.original,
96 "#[error(transparent)] requires exactly one field",
97 ));
98 }
99 if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
100 return Err(Error::new_spanned(
101 source,
102 "transparent variant can't contain #[source]",
103 ));
104 }
105 }
106 check_field_attrs(&self.fields)?;
107 for field in &self.fields {
108 field.validate()?;
109 }
110 Ok(())
111 }
112}
113
114impl Field<'_> {
115 fn validate(&self) -> Result<()> {
116 if let Some(display) = &self.attrs.display {
117 return Err(Error::new_spanned(
118 display.original,
119 "not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant",
120 ));
121 }
122 Ok(())
123 }
124}
125
126fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
127 if let Some(from) = &attrs.from {
128 return Err(Error::new_spanned(
129 from,
130 "not expected here; the #[from] attribute belongs on a specific field",
131 ));
132 }
133 if let Some(source) = &attrs.source {
134 return Err(Error::new_spanned(
135 source,
136 "not expected here; the #[source] attribute belongs on a specific field",
137 ));
138 }
139 if let Some(display) = &attrs.display {
140 if attrs.transparent.is_some() {
141 return Err(Error::new_spanned(
142 display.original,
143 "cannot have both #[error(transparent)] and a display attribute",
144 ));
145 }
146 }
147 Ok(())
148}
149
150fn check_field_attrs(fields: &[Field]) -> Result<()> {
151 let mut from_field = None;
152 let mut source_field = None;
153 for field in fields {
154 if let Some(from) = field.attrs.from {
155 if from_field.is_some() {
156 return Err(Error::new_spanned(from, "duplicate #[from] attribute"));
157 }
158 from_field = Some(field);
159 }
160 if let Some(source) = field.attrs.source {
161 if source_field.is_some() {
162 return Err(Error::new_spanned(source, "duplicate #[source] attribute"));
163 }
164 source_field = Some(field);
165 }
166 if let Some(transparent) = field.attrs.transparent {
167 return Err(Error::new_spanned(
168 transparent.original,
169 "#[error(transparent)] needs to go outside the enum or struct, not on an individual field",
170 ));
171 }
172 }
173 if let (Some(from_field), Some(source_field)) = (from_field, source_field) {
174 if !same_member(from_field, source_field) {
175 return Err(Error::new_spanned(
176 from_field.attrs.from,
177 "#[from] is only supported on the source field, not any other field",
178 ));
179 }
180 }
181 if let Some(from_field) = from_field {
182 let max_expected_fields = 1;
183 if fields.len() > max_expected_fields {
184 return Err(Error::new_spanned(
185 from_field.attrs.from,
186 "deriving From requires no fields other than source",
187 ));
188 }
189 }
190 if let Some(source_field) = source_field.or(from_field) {
191 if contains_non_static_lifetime(source_field.ty) {
192 return Err(Error::new_spanned(
193 &source_field.original.ty,
194 "non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static",
195 ));
196 }
197 }
198 Ok(())
199}
200
201fn same_member(one: &Field, two: &Field) -> bool {
202 match (&one.member, &two.member) {
203 (Member::Named(one), Member::Named(two)) => one == two,
204 (Member::Unnamed(one), Member::Unnamed(two)) => one.index == two.index,
205 _ => unreachable!(),
206 }
207}
208
209fn contains_non_static_lifetime(ty: &Type) -> bool {
210 match ty {
211 Type::Path(ty) => {
212 let bracketed = match &ty.path.segments.last().unwrap().arguments {
213 PathArguments::AngleBracketed(bracketed) => bracketed,
214 _ => return false,
215 };
216 for arg in &bracketed.args {
217 match arg {
218 GenericArgument::Type(ty) if contains_non_static_lifetime(ty) => return true,
219 GenericArgument::Lifetime(lifetime) if lifetime.ident != "static" => {
220 return true
221 }
222 _ => {}
223 }
224 }
225 false
226 }
227 Type::Reference(ty) => ty
228 .lifetime
229 .as_ref()
230 .map_or(false, |lifetime| lifetime.ident != "static"),
231 _ => false, }
233}