1use std::collections::HashMap;
2use std::fmt::Write as _;
3
4use quote::ToTokens;
5use syn::__private::Span;
6use syn::{Data, GenericParam, Lifetime, LifetimeDef, Token};
7
8use super::*;
9use crate::parse_attributes::{parse_attrs, VariantInfo};
10
11pub fn fetch_name_with_generic_params(ast: &DeriveInput) -> (String, Vec<String>) {
24 let mut param_string = String::new();
25 let mut lifetimes = vec![];
26 for param in ast.generics.params.iter() {
27 let next = match param {
28 syn::GenericParam::Type(type_) => type_.ident.to_token_stream(),
29 syn::GenericParam::Lifetime(life_def) => {
30 let lifetime = life_def.lifetime.to_token_stream();
31 lifetimes.push(lifetime.to_string());
32 lifetime
33 }
34 syn::GenericParam::Const(constant) => constant.ident.to_token_stream(),
35 };
36 _ = write!(param_string, "{},", next);
37 }
38 param_string.pop();
39 if !param_string.is_empty() {
40 (format!("{}<{}>", ast.ident, param_string), lifetimes)
41 } else {
42 (ast.ident.to_string(), lifetimes)
43 }
44}
45
/// Generic-parameter strings needed to emit the `impl` blocks for a derived enum,
/// as produced by `fetch_impl_generics`.
///
/// Every field holds a token stream rendered with `to_string`, so tokens are
/// space-separated (for example `< 'a , T : Debug >`).
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ImplGenerics {
    /// The enum's own generic parameter list (bounds kept), without its where clause.
    pub impl_generics: String,
    /// The enum's generic parameter list with one additional, bounded lifetime added.
    pub impl_generics_ref: String,
    /// The enum's `where` clause, or an empty string when it has none.
    pub where_clause: String,
}
60
61pub fn fetch_impl_generics(ast: &DeriveInput, lifetime: &str, bounds: &[String]) -> ImplGenerics {
88 let mut generics = ast.generics.clone();
89 let mut generics_ref = generics.clone();
90 generics_ref
91 .params
92 .push(GenericParam::Lifetime(bound_lifetime(lifetime, bounds)));
93
94 let where_clause = generics
95 .where_clause
96 .take()
97 .map(|w| w.to_token_stream().to_string());
98 ImplGenerics {
99 impl_generics: generics.to_token_stream().to_string(),
100 impl_generics_ref: generics_ref.to_token_stream().to_string(),
101 where_clause: where_clause.unwrap_or_default(),
102 }
103}
104
105pub fn bound_lifetime(lifetime: &str, bounds: &[String]) -> syn::LifetimeDef {
109 let mut lifetime_def = LifetimeDef::new(Lifetime::new(lifetime, Span::call_site()));
110 lifetime_def.colon_token = if bounds.is_empty() {
111 Some(Token))
112 } else {
113 None
114 };
115 lifetime_def.bounds = bounds
116 .iter()
117 .map(|lifetime| Lifetime::new(lifetime, Span::call_site()))
118 .collect();
119 lifetime_def
120}
121
122pub(crate) fn fetch_fields_from_enum(ast: &mut DeriveInput) -> HashMap<String, VariantInfo> {
133 let derive_globally = parse_attrs(&mut ast.attrs);
134 if let Data::Enum(data) = &mut ast.data {
135 let mut num_fields: usize = 0;
136 let mut types = data
137 .variants
138 .iter_mut()
139 .map(|var| match &var.fields {
140 syn::Fields::Unnamed(field_) => {
141 if field_.unnamed.len() != 1 {
142 panic!(
143 "Can only derive for enums whose types do \
144 not contain multiple fields."
145 );
146 }
147 let var_ty = field_
148 .unnamed
149 .iter()
150 .next()
151 .unwrap()
152 .ty
153 .to_token_stream()
154 .to_string();
155 let var_name = var.ident.to_token_stream().to_string();
156 let var_info = VariantInfo {
157 ty: var_ty,
158 try_from: parse_attrs(&mut var.attrs) || derive_globally,
159 };
160 num_fields += 1;
161 (var_info, var_name)
162 }
163 syn::Fields::Named(_) => {
164 panic!("Can only derive for enums whose types do not have named fields.")
165 }
166 syn::Fields::Unit => {
167 panic!("Can only derive for enums who don't contain unit types as variants.")
168 }
169 })
170 .collect::<HashMap<VariantInfo, String>>();
171 let types: HashMap<String, VariantInfo> = types.drain().map(|(k, v)| (v, k)).collect();
172 if types.keys().len() != num_fields {
173 panic!("Cannot derive for enums with more than one field with the same type.")
174 }
175 types
176 } else {
177 panic!("Can only derive for enums.")
178 }
179}
180
/// Generates the source of a private module `enum___conversion___<name>` that
/// contains one empty `pub(crate) enum` per key of `types`.
///
/// The key is the variant name; the map's values are not read. Keys are emitted
/// in sorted order, so the generated code is identical on every run. Plain
/// `HashMap` iteration order is arbitrary and would make the macro output
/// nondeterministic.
pub(crate) fn create_marker_enums<V>(name: &str, types: &HashMap<String, V>) -> String {
    let mut fields: Vec<&String> = types.keys().collect();
    fields.sort_unstable();

    let mut piece = format!(
        "#[allow(non_snake_case)]\n mod enum___conversion___{}{{ ",
        name
    );
    for field in fields {
        // Writing into a `String` cannot fail.
        _ = write!(piece, "pub(crate) enum {}{{}}", field);
    }
    piece.push('}');
    piece
}
198
/// Returns the path `enum___conversion___<name>::<field>`, which names the
/// marker enum that `create_marker_enums` emits for `field`.
pub fn get_marker(name: &str, field: &str) -> String {
    ["enum___conversion___", name, "::", field].concat()
}
204
#[cfg(test)]
mod test_parsers {
    // Unit tests for the helpers above. Inputs are parsed with `syn::parse_str`.
    // Expected type strings use the space-separated form produced by
    // `TokenStream::to_string` (e.g. `[u8 ; 20]`).

    use super::*;

    // Enum covering lifetimes, a bounded type parameter, a where clause, an
    // unrelated attribute (`#[help]`) and many payload-type syntaxes.
    const ENUM: &str = r#"
        enum Enum<'a, 'b, T, U: Debug>
        where T: Into<U>, U: 'a
        {
            #[help]
            Array([u8; 20]),
            BareFn(fn(&'a usize) -> bool),
            Macro(typey!()),
            Path(<Vec<&'a mut T> as IntoIterator>::Item),
            Ptr(*const u8),
            Tuple((&'b i64, bool)),
            Slice([u8]),
            Trait(Box<&dyn Into<U>>),
        }
        "#;

    /// Every variant's payload type is captured verbatim as a token string.
    /// `.into()` from `&str` presumably yields `try_from: false`; compare with
    /// `test_try_from_local_config`.
    #[test]
    fn test_parse_fields_and_types() {
        let mut ast: DeriveInput = syn::parse_str(ENUM).expect("Test failed.");
        let fields = fetch_fields_from_enum(&mut ast);
        let expected: HashMap<String, VariantInfo> = HashMap::from([
            ("Array".to_string(), "[u8 ; 20]".into()),
            ("BareFn".to_string(), "fn (& 'a usize) -> bool".into()),
            ("Macro".to_string(), "typey ! ()".into()),
            (
                "Path".to_string(),
                "< Vec < & 'a mut T > as IntoIterator > :: Item".into(),
            ),
            ("Ptr".to_string(), "* const u8".into()),
            ("Slice".to_string(), "[u8]".into()),
            ("Trait".to_string(), "Box < & dyn Into < U > >".into()),
            ("Tuple".to_string(), "(& 'b i64 , bool)".into()),
        ]);
        assert_eq!(expected, fields);
    }

    /// `#[DeriveTryFrom]` on the enum itself enables `try_from` for every variant.
    #[test]
    fn test_global_try_from_config() {
        let mut ast: DeriveInput = syn::parse_str(
            r#"
            #[DeriveTryFrom]
            enum Enum {
                F1(i64),
                F2(bool),
            }
            "#,
        )
        .expect("Test failed");
        let fields = fetch_fields_from_enum(&mut ast);
        let expected: HashMap<String, VariantInfo> = HashMap::from([
            (
                "F1".to_string(),
                VariantInfo {
                    ty: "i64".to_string(),
                    try_from: true,
                },
            ),
            (
                "F2".to_string(),
                VariantInfo {
                    ty: "bool".to_string(),
                    try_from: true,
                },
            ),
        ]);
        assert_eq!(fields, expected);
    }

    /// `#[DeriveTryFrom]` on a single variant enables `try_from` for that variant only.
    #[test]
    fn test_try_from_local_config() {
        let mut ast: DeriveInput = syn::parse_str(
            r#"
            enum Enum {
                F1(i64),
                #[DeriveTryFrom]
                F2(bool),
            }
            "#,
        )
        .expect("Test failed");
        let fields = fetch_fields_from_enum(&mut ast);
        let expected: HashMap<String, VariantInfo> = HashMap::from([
            ("F1".to_string(), "i64".into()),
            (
                "F2".to_string(),
                VariantInfo {
                    ty: "bool".to_string(),
                    try_from: true,
                },
            ),
        ]);
        assert_eq!(fields, expected);
    }

    /// The extra conversion lifetime is bounded by all of the enum's lifetimes.
    /// syn prints lifetimes before type parameters, hence its position (and the
    /// trailing comma) in `impl_generics_ref`. `ENUM_CONV_LIFETIME` is defined
    /// outside this view; the expected output implies it is `'enum_conv`.
    #[test]
    fn test_generics_and_bounds() {
        let ast: DeriveInput = syn::parse_str(ENUM).expect("Test failed.");
        let (_, lifetimes) = fetch_name_with_generic_params(&ast);
        let ImplGenerics {
            impl_generics,
            impl_generics_ref,
            where_clause,
        } = fetch_impl_generics(&ast, ENUM_CONV_LIFETIME, &lifetimes);
        assert_eq!(impl_generics, "< 'a , 'b , T , U : Debug >");
        assert_eq!(
            impl_generics_ref,
            "< 'a , 'b , 'enum_conv : 'a + 'b , T , U : Debug , >"
        );
        assert_eq!(where_clause, "where T : Into < U > , U : 'a");
    }

    /// The rendered name drops bounds; lifetimes are reported in declaration order.
    #[test]
    fn test_get_name_with_generics() {
        let ast: DeriveInput = syn::parse_str(ENUM).expect("Test failed.");
        let (name, lifetimes) = fetch_name_with_generic_params(&ast);
        assert_eq!(name, "Enum<'a,'b,T,U>");
        assert_eq!(lifetimes, vec![String::from("'a"), String::from("'b")]);
    }

    /// Non-enum inputs are rejected.
    #[test]
    #[should_panic(expected = "Can only derive for enums.")]
    fn test_panic_on_struct() {
        let mut ast = syn::parse_str("pub struct Struct;").expect("Test failed");
        _ = fetch_fields_from_enum(&mut ast);
    }

    /// Struct-like variants are rejected.
    #[test]
    #[should_panic(expected = "Can only derive for enums whose types do not have named fields.")]
    fn test_panic_on_field_with_named_types() {
        let mut ast = syn::parse_str(
            r#"
            enum Enum {
                F{a: i64},
            }
            "#,
        )
        .expect("Test failed");
        _ = fetch_fields_from_enum(&mut ast);
    }

    /// Two variants with the same payload type are rejected.
    #[test]
    #[should_panic(
        expected = "Cannot derive for enums with more than one field with the same type."
    )]
    fn test_multiple_fields_same_type() {
        let mut ast = syn::parse_str(
            r#"
            enum Enum {
                F1(u64),
                F2(u64),
            }
            "#,
        )
        .expect("Test failed");
        _ = fetch_fields_from_enum(&mut ast);
    }

    /// Tuple variants with more than one field are rejected.
    #[test]
    #[should_panic(
        expected = "Can only derive for enums whose types do not contain multiple fields."
    )]
    fn test_multiple_types_in_field() {
        let mut ast = syn::parse_str(
            r#"
            enum Enum {
                Field(i64, bool),
            }
            "#,
        )
        .expect("Test failed");
        _ = fetch_fields_from_enum(&mut ast);
    }

    /// Unit variants are rejected.
    #[test]
    #[should_panic(
        expected = "Can only derive for enums who don't contain unit types as variants."
    )]
    fn test_unit_type() {
        let mut ast = syn::parse_str(
            r#"
            enum Enum {
                Some(bool),
                None,
            }
            "#,
        )
        .expect("Test failed");
        _ = fetch_fields_from_enum(&mut ast);
    }

    /// An enum with no variants is accepted and yields an empty map.
    #[test]
    fn test_harmless() {
        let mut ast = syn::parse_str(r#"enum Enum{ }"#).expect("Test failed");
        let fields = fetch_fields_from_enum(&mut ast);
        assert!(fields.is_empty())
    }

    /// The marker module contains one empty enum per variant name.
    #[test]
    fn test_create_marker_structs() {
        let mut ast = syn::parse_str(
            r#"
            enum Enum {
                F1(u64)
            }
            "#,
        )
        .expect("Test failed.");
        let fields = fetch_fields_from_enum(&mut ast);
        let output = create_marker_enums(&ast.ident.to_string(), &fields);
        assert_eq!(
            output,
            "#[allow(non_snake_case)]\n mod enum___conversion___Enum{ pub(crate) enum F1{}}"
        );
    }
}