1#![doc = include_str!("../README.md")]
2
3mod tests;
4
5use proc_macro2::TokenStream;
6use proc_macro_error::{abort, SpanRange};
7use quote::quote;
8use std::str::FromStr;
9use strum::{Display, EnumString};
10use syn::fold::Fold;
11use syn::WhereClause;
12use syn::{
13 parse2, parse_quote, parse_str, punctuated::Punctuated, token::Comma, Block, FnArg,
14 GenericArgument, GenericParam, Generics, Ident, ItemFn, Lifetime, Pat, PatIdent, PatType,
15 PathArguments, Signature, Stmt, Type, TypePath, WherePredicate,
16};
17
18pub fn anyinput_core(args: TokenStream, input: TokenStream) -> TokenStream {
19 if !args.is_empty() {
20 abort!(args, "anyinput does not take any arguments.")
21 }
22
23 let old_item_fn = match parse2::<ItemFn>(input) {
25 Ok(syntax_tree) => syntax_tree,
26 Err(error) => return error.to_compile_error(),
27 };
28
29 let new_item_fn = transform_fn(old_item_fn);
30
31 quote!(#new_item_fn)
32}
33
34pub fn anyinput_core_sample(args: TokenStream, input: TokenStream) -> TokenStream {
35 if !args.is_empty() {
36 abort!(args, "anyinput does not take any arguments.")
37 }
38
39 let old_item_fn = match parse2::<ItemFn>(input) {
41 Ok(syntax_tree) => syntax_tree,
42 Err(error) => return error.to_compile_error(),
43 };
44
45 let new_item_fn = transform_fn_sample(old_item_fn);
46
47 quote!(#new_item_fn)
48}
49
50fn transform_fn_sample(_item_fn: ItemFn) -> ItemFn {
51 println!("input code : {}", quote!(#_item_fn));
52 println!("input syntax: {:?}", _item_fn);
53 parse_quote! {
54 fn hello_world() {
55 println!("Hello, world!");
56 }
57 }
58}
59
60fn transform_fn(item_fn: ItemFn) -> ItemFn {
61 let mut suffix_iter = simple_suffix_iter_factory();
62 let delta_fn_arg_new = |fn_arg| DeltaFnArg::new(fn_arg, &mut suffix_iter);
63
64 item_fn
67 .sig
68 .inputs
69 .iter()
70 .map(delta_fn_arg_new)
71 .fold(ItemFnAcc::init(&item_fn), ItemFnAcc::fold)
72 .to_item_fn()
73}
74
75struct ItemFnAcc<'a> {
76 old_fn: &'a ItemFn,
77 fn_args: Punctuated<FnArg, Comma>,
78 generic_params: Punctuated<GenericParam, Comma>,
79 where_predicates: Punctuated<WherePredicate, Comma>,
80 stmts: Vec<Stmt>,
81}
82
83impl ItemFnAcc<'_> {
84 fn init(item_fn: &ItemFn) -> ItemFnAcc {
85 ItemFnAcc {
87 old_fn: item_fn,
88 fn_args: Punctuated::<FnArg, Comma>::new(),
89 generic_params: item_fn.sig.generics.params.clone(),
90 where_predicates: ItemFnAcc::extract_where_predicates(item_fn),
91 stmts: item_fn.block.stmts.clone(),
92 }
93 }
94
95 fn extract_where_predicates(item_fn: &ItemFn) -> Punctuated<WherePredicate, Comma> {
97 if let Some(WhereClause { predicates, .. }) = &item_fn.sig.generics.where_clause {
98 predicates.clone()
99 } else {
100 parse_quote!()
101 }
102 }
103
104 fn fold(mut self, delta: DeltaFnArg) -> Self {
105 self.fn_args.push(delta.fn_arg);
106 self.generic_params.extend(delta.generic_params);
107 self.where_predicates.extend(delta.where_predicates);
108 for (index, element) in delta.stmt.into_iter().enumerate() {
109 self.stmts.insert(index, element);
110 }
111 self
112 }
113
114 fn to_item_fn(&self) -> ItemFn {
116 ItemFn {
117 sig: Signature {
118 generics: self.to_generics(),
119 inputs: self.fn_args.clone(),
120 ..self.old_fn.sig.clone()
121 },
122 block: Box::new(Block {
123 stmts: self.stmts.clone(),
124 ..*self.old_fn.block.clone()
125 }),
126 ..self.old_fn.clone()
127 }
128 }
129
130 fn to_generics(&self) -> Generics {
131 Generics {
132 lt_token: parse_quote!(<),
133 params: self.generic_params.clone(),
134 gt_token: parse_quote!(>),
135 where_clause: self.to_where_clause(),
136 }
137 }
138
139 fn to_where_clause(&self) -> Option<WhereClause> {
140 if self.where_predicates.is_empty() {
141 None
142 } else {
143 Some(WhereClause {
144 where_token: parse_quote!(where),
145 predicates: self.where_predicates.clone(),
146 })
147 }
148 }
149}
150
/// Returns an endless iterator of decimal name suffixes: `"0"`, `"1"`, `"2"`, ...
fn simple_suffix_iter_factory() -> impl Iterator<Item = String> + 'static {
    (0usize..).map(|counter| counter.to_string())
}
156}
157
158#[derive(Debug, Clone, EnumString, Display)]
160#[allow(clippy::enum_variant_names)]
161enum Special {
162 AnyArray,
163 AnyString,
164 AnyPath,
165 AnyIter,
166 AnyNdArray,
167}
168
169impl Special {
170 fn special_to_where_predicate(
171 &self,
172 generic: &TypePath, maybe_sub_type: Option<Type>,
174 maybe_lifetime: Option<Lifetime>,
175 span_range: &SpanRange,
176 ) -> WherePredicate {
177 match &self {
178 Special::AnyString => {
179 if maybe_sub_type.is_some() {
180 abort!(span_range,"AnyString should not have a generic parameter, so 'AnyString', not 'AnyString<_>'.")
181 };
182 if maybe_lifetime.is_some() {
183 abort!(span_range, "AnyString should not have a lifetime.")
184 };
185 parse_quote! {
186 #generic : AsRef<str>
187 }
188 }
189 Special::AnyPath => {
190 if maybe_sub_type.is_some() {
191 abort!(span_range,"AnyPath should not have a generic parameter, so 'AnyPath', not 'AnyPath<_>'.")
192 };
193 if maybe_lifetime.is_some() {
194 abort!(span_range, "AnyPath should not have a lifetime.")
195 };
196 parse_quote! {
197 #generic : AsRef<std::path::Path>
198 }
199 }
200 Special::AnyArray => {
201 let sub_type = match maybe_sub_type {
202 Some(sub_type) => sub_type,
203 None => {
204 abort!(span_range,"AnyArray expects a generic parameter, for example, AnyArray<usize> or AnyArray<AnyString>.")
205 }
206 };
207 if maybe_lifetime.is_some() {
208 abort!(span_range, "AnyArray should not have a lifetime.")
209 };
210 parse_quote! {
211 #generic : AsRef<[#sub_type]>
212 }
213 }
214 Special::AnyIter => {
215 let sub_type = match maybe_sub_type {
216 Some(sub_type) => sub_type,
217 None => {
218 abort!(span_range,"AnyIter expects a generic parameter, for example, AnyIter<usize> or AnyIter<AnyString>.")
219 }
220 };
221 if maybe_lifetime.is_some() {
222 abort!(span_range, "AnyIter should not have a lifetime.")
223 };
224 parse_quote! {
225 #generic : IntoIterator<Item = #sub_type>
226 }
227 }
228 Special::AnyNdArray => {
229 let sub_type = match maybe_sub_type {
230 Some(sub_type) => sub_type,
231 None => {
232 abort!(span_range,"AnyNdArray expects a generic parameter, for example, AnyNdArray<usize> or AnyNdArray<AnyString>.")
233 }
234 };
235 let lifetime =
236 maybe_lifetime.expect("Internal error: AnyNdArray should be given a lifetime.");
237 parse_quote! {
238 #generic: Into<ndarray::ArrayView1<#lifetime, #sub_type>>
239 }
240 }
241 }
242 }
243
244 fn ident_to_stmt(&self, name: &Ident) -> Stmt {
245 match &self {
246 Special::AnyArray | Special::AnyString | Special::AnyPath => {
247 parse_quote! {
248 let #name = #name.as_ref();
249 }
250 }
251 Special::AnyIter => {
252 parse_quote! {
253 let #name = #name.into_iter();
254 }
255 }
256 Special::AnyNdArray => {
257 parse_quote! {
258 let #name = #name.into();
259 }
260 }
261 }
262 }
263
264 fn should_add_lifetime(&self) -> bool {
265 match self {
266 Special::AnyArray | Special::AnyString | Special::AnyPath | Special::AnyIter => false,
267 Special::AnyNdArray => true,
268 }
269 }
270
271 fn maybe_new(type_path: &TypePath, span_range: &SpanRange) -> Option<(Special, Option<Type>)> {
272 if type_path.qself.is_none() {
274 if let Some(segment) = first_and_only(type_path.path.segments.iter()) {
275 if let Ok(special) = Special::from_str(segment.ident.to_string().as_ref()) {
276 let maybe_sub_type =
277 Special::create_maybe_sub_type(&segment.arguments, span_range);
278 return Some((special, maybe_sub_type));
279 }
280 }
281 }
282 None
283 }
284
285 fn create_maybe_sub_type(args: &PathArguments, span_range: &SpanRange) -> Option<Type> {
286 match args {
287 PathArguments::None => None,
288 PathArguments::AngleBracketed(ref args) => {
289 let arg = first_and_only(args.args.iter()).unwrap_or_else(|| {
290 abort!(span_range, "Expected at exactly one generic parameter.")
291 });
292 if let GenericArgument::Type(sub_type2) = arg {
293 Some(sub_type2.clone())
294 } else {
295 abort!(span_range, "Expected generic parameter to be a type.")
296 }
297 }
298 PathArguments::Parenthesized(_) => {
299 abort!(span_range, "Expected <..> generic parameter.")
300 }
301 }
302 }
303
304 fn to_snake_case(&self) -> String {
307 let mut snake_case_string = String::new();
308 for (index, ch) in self.to_string().chars().enumerate() {
309 if index > 0 && ch.is_uppercase() {
310 snake_case_string.push('_');
311 }
312 snake_case_string.push(ch.to_ascii_lowercase());
313 }
314 snake_case_string
315 }
316}
317
318#[derive(Debug)]
319struct DeltaFnArg {
321 fn_arg: FnArg,
322 generic_params: Vec<GenericParam>,
323 where_predicates: Vec<WherePredicate>,
324 stmt: Option<Stmt>,
325}
326
327impl DeltaFnArg {
328 fn new(fn_arg: &FnArg, suffix_iter: &mut impl Iterator<Item = String>) -> DeltaFnArg {
330 if let Some((pat_ident, pat_type)) = DeltaFnArg::is_normal_fn_arg(fn_arg) {
332 DeltaFnArg::replace_any_specials(pat_type.clone(), pat_ident, suffix_iter)
334 } else {
335 DeltaFnArg {
337 fn_arg: fn_arg.clone(),
338 generic_params: vec![],
339 where_predicates: vec![],
340 stmt: None,
341 }
342 }
343 }
344
345 fn is_normal_fn_arg(fn_arg: &FnArg) -> Option<(&PatIdent, &PatType)> {
347 if let FnArg::Typed(pat_type) = fn_arg {
348 if let Pat::Ident(pat_ident) = &*pat_type.pat {
349 if let Type::Path(_) = &*pat_type.ty {
350 return Some((pat_ident, pat_type));
351 }
352 }
353 }
354 None
355 }
356
357 #[allow(clippy::ptr_arg)]
362 fn replace_any_specials(
363 old_pat_type: PatType,
364 pat_ident: &PatIdent,
365 suffix_iter: &mut impl Iterator<Item = String>,
366 ) -> DeltaFnArg {
367 let mut delta_pat_type = DeltaPatType::new(suffix_iter);
368 let new_pat_type = delta_pat_type.fold_pat_type(old_pat_type);
369
370 DeltaFnArg {
372 fn_arg: FnArg::Typed(new_pat_type),
373 stmt: delta_pat_type.generate_any_stmt(pat_ident),
374 generic_params: delta_pat_type.generic_params,
375 where_predicates: delta_pat_type.where_predicates,
376 }
377 }
378}
379
380struct DeltaPatType<'a> {
381 generic_params: Vec<GenericParam>,
382 where_predicates: Vec<WherePredicate>,
383 suffix_iter: &'a mut dyn Iterator<Item = String>,
384 last_special: Option<Special>,
385}
386
387impl Fold for DeltaPatType<'_> {
388 fn fold_type_path(&mut self, type_path_old: TypePath) -> TypePath {
389 let span_range = SpanRange::from_tokens(&type_path_old); let type_path_middle = syn::fold::fold_type_path(self, type_path_old);
393
394 if let Some((special, maybe_sub_types)) = Special::maybe_new(&type_path_middle, &span_range)
396 {
397 self.last_special = Some(special.clone()); self.create_and_define_generic(special, maybe_sub_types, &span_range)
399 } else {
400 self.last_special = None;
401 type_path_middle
402 }
403 }
404}
405
406impl<'a> DeltaPatType<'a> {
407 fn new(suffix_iter: &'a mut dyn Iterator<Item = String>) -> Self {
408 DeltaPatType {
409 generic_params: vec![],
410 where_predicates: vec![],
411 suffix_iter,
412 last_special: None,
413 }
414 }
415
416 fn generate_any_stmt(&self, pat_ident: &PatIdent) -> Option<Stmt> {
420 if let Some(special) = &self.last_special {
421 let stmt = special.ident_to_stmt(&pat_ident.ident);
422 Some(stmt)
423 } else {
424 None
425 }
426 }
427
428 fn create_and_define_generic(
430 &mut self,
431 special: Special,
432 maybe_sub_type: Option<Type>,
433 span_range: &SpanRange,
434 ) -> TypePath {
435 let generic = self.create_generic(&special); let maybe_lifetime = self.create_maybe_lifetime(&special);
437 let where_predicate = special.special_to_where_predicate(
438 &generic,
439 maybe_sub_type,
440 maybe_lifetime,
441 span_range,
442 );
443 let generic_param: GenericParam = parse_quote!(#generic);
444 self.generic_params.push(generic_param);
445 self.where_predicates.push(where_predicate);
446 generic
447 }
448
449 fn create_maybe_lifetime(&mut self, special: &Special) -> Option<Lifetime> {
451 if special.should_add_lifetime() {
452 let lifetime = self.create_lifetime(special);
453 let generic_param: GenericParam = parse_quote!(#lifetime);
454 self.generic_params.push(generic_param);
455
456 Some(lifetime)
457 } else {
458 None
459 }
460 }
461
462 fn create_generic(&mut self, special: &Special) -> TypePath {
464 let suffix = self.create_suffix();
465 let generic_name = format!("{}{}", &special, suffix);
466 parse_str(&generic_name).expect("Internal error: failed to parse generic name")
467 }
468
469 fn create_lifetime(&mut self, special: &Special) -> Lifetime {
471 let lifetime_name = format!("'{}{}", special.to_snake_case(), self.create_suffix());
472 parse_str(&lifetime_name).expect("Internal error: failed to parse lifetime name")
473 }
474
475 fn create_suffix(&mut self) -> String {
477 self.suffix_iter
478 .next()
479 .expect("Internal error: ran out of generic suffixes")
480 }
481}
482
/// Returns the iterator's only item, or `None` if it yields zero items or more
/// than one.
///
/// At most two items are pulled from `iter`.
fn first_and_only<T, I: Iterator<Item = T>>(mut iter: I) -> Option<T> {
    match iter.next() {
        Some(only) if iter.next().is_none() => Some(only),
        _ => None,
    }
}
491}
492
493