enso_macro_utils/
lib.rs

//! Helper functions for implementing the procedural macros defined in `enso-shapely-macros`.
3
4#![warn(missing_docs)]
5#![feature(trait_alias)]
6
7use proc_macro2::TokenStream;
8use proc_macro2::TokenTree;
9use quote::quote;
10use std::iter::FromIterator;
11use syn::visit::Visit;
12use syn::WhereClause;
13use syn::WherePredicate;
14
15
16
17// =====================
18// === Trait Aliases ===
19// =====================
20
21pub trait Str = Into<String> + AsRef<str>;
22
23
24
25// ==========================
26// === Token Stream Utils ===
27// ==========================
28
29/// Maps all the tokens in the stream using a given function.
30pub fn map_tokens<F:Fn(TokenTree) -> TokenTree>
31(input:TokenStream, f:F) -> TokenStream {
32    let ret_iter = input.into_iter().map(f);
33    ret_iter.collect()
34}
35
36/// Rewrites stream replacing each token with a sequence of tokens returned by
37/// the given function. The groups (e.g. token tree within braces) are unpacked,
38/// rewritten and repacked into groups -- the function is applied recursively.
39pub fn rewrite_stream
40<F:Fn(TokenTree) -> TokenStream + Copy>
41(input:TokenStream, f:F) -> TokenStream {
42    let mut ret = TokenStream::new();
43    for token in input.into_iter() {
44        match token {
45            proc_macro2::TokenTree::Group(group) => {
46                let delim  = group.delimiter();
47                let span   = group.span();
48                let rewritten = rewrite_stream(group.stream(), f);
49                let mut new_group = proc_macro2::Group::new(delim,rewritten);
50                new_group.set_span(span);
51                let new_group = vec![TokenTree::from(new_group)];
52                ret.extend(new_group.into_iter())
53            }
54            _ => ret.extend(f(token)),
55        }
56    }
57    ret
58}
59
60
61
62// ===================
63// === Token Utils ===
64// ===================
65
66/// Is the given token an identifier matching to a given string?
67pub fn matching_ident(token:&TokenTree, name:&str) -> bool {
68    match token {
69        TokenTree::Ident(ident) => *ident == name,
70        _                       => false,
71    }
72}
73
74
75
76// ============
77// === Repr ===
78// ============
79
80/// Obtains text representation of given `ToTokens`-compatible input.
81pub fn repr<T: quote::ToTokens>(t:&T) -> String {
82    quote!(#t).to_string()
83}
84
85
86
87// ===================
88// === Field Utils ===
89// ===================
90
91/// Collects all fields, named or not.
92pub fn fields_list(fields:&syn::Fields) -> Vec<&syn::Field> {
93    match fields {
94        syn::Fields::Named  (ref f) => f.named  .iter().collect(),
95        syn::Fields::Unnamed(ref f) => f.unnamed.iter().collect(),
96        syn::Fields::Unit           => Default::default(),
97    }
98}
99
100/// Returns token that refers to the field.
101///
102/// It is the field name for named field and field index for unnamed fields.
103pub fn field_ident_token(field:&syn::Field, index:syn::Index) -> TokenStream {
104    match &field.ident {
105        Some(ident) => quote!(#ident),
106        None        => quote!(#index),
107    }
108}
109
110/// Returns names of the named fields.
111pub fn field_names(fields:&syn::FieldsNamed) -> Vec<&syn::Ident> {
112    fields.named.iter().map(|field| {
113        field.ident.as_ref().expect("Impossible: no name on a named field.")
114    }).collect()
115}
116
117
118
119// ==================
120// === Path Utils ===
121// ==================
122
123/// Checks if a given `Path` consists of a single identifier same as given string.
124pub fn path_matching_ident(path:&syn::Path, str:impl Str) -> bool {
125    path.get_ident().map_or(false, |ident| ident == str.as_ref())
126}
127
128
129
130// ======================
131// === Index Sequence ===
132// ======================
133
134/// For given length, returns a sequence of Literals like `[0,1,2…]`. These are unsuffixed
135/// usize literals, so e.g. can be used to identify the tuple unnamed fields.
136pub fn index_sequence(len:usize) -> Vec<syn::Index> {
137    (0..len).map(syn::Index::from).collect()
138}
139
140/// For given length returns sequence of identifiers like `[field0,field1,…]`.
141pub fn identifier_sequence(len:usize) -> Vec<syn::Ident> {
142    let format_field = |ix| quote::format_ident!("field{}",ix);
143    (0..len).map(format_field).collect()
144}
145
146
147
148// =======================
149// === Type Path Utils ===
150// =======================
151
152/// Obtain list of generic arguments on the path's segment.
153pub fn path_segment_generic_args
154(segment:&syn::PathSegment) -> Vec<&syn::GenericArgument> {
155    match segment.arguments {
156        syn::PathArguments::AngleBracketed(ref args) =>
157            args.args.iter().collect(),
158        _ =>
159            Vec::new(),
160    }
161}
162
163/// Obtain list of generic arguments on the path's last segment.
164///
165/// Empty, if path contains no segments.
166pub fn ty_path_generic_args
167(ty_path:&syn::TypePath) -> Vec<&syn::GenericArgument> {
168    ty_path.path.segments.last().map_or(Vec::new(), path_segment_generic_args)
169}
170
171/// Obtain list of type arguments on the path's last segment.
172pub fn ty_path_type_args
173(ty_path:&syn::TypePath) -> Vec<&syn::Type> {
174    ty_path_generic_args(ty_path).iter().filter_map( |generic_arg| {
175        match generic_arg {
176            syn::GenericArgument::Type(t) => Some(t),
177            _                             => None,
178        }
179    }).collect()
180}
181
182/// Last type argument of the last segment on the type path.
183pub fn last_type_arg(ty_path:&syn::TypePath) -> Option<&syn::GenericArgument> {
184    ty_path_generic_args(ty_path).last().copied()
185}
186
187
188
189// =====================
190// === Collect Types ===
191// =====================
192
193/// Visitor that accumulates all visited `syn::TypePath`.
194#[derive(Default)]
195pub struct TypeGatherer<'ast> {
196    /// Observed types accumulator.
197    pub types: Vec<&'ast syn::TypePath>
198}
199
200impl<'ast> Visit<'ast> for TypeGatherer<'ast> {
201    fn visit_type_path(&mut self, node:&'ast syn::TypePath) {
202        self.types.push(node);
203        syn::visit::visit_type_path(self, node);
204    }
205}
206
207/// All `TypePath`s in the given's `Type` subtree.
208pub fn gather_all_types(node:&syn::Type) -> Vec<&syn::TypePath> {
209    let mut type_gather = TypeGatherer::default();
210    type_gather.visit_type(node);
211    type_gather.types
212}
213
214/// All text representations of `TypePath`s in the given's `Type` subtree.
215pub fn gather_all_type_reprs(node:&syn::Type) -> Vec<String> {
216    gather_all_types(node).iter().map(|t| repr(t)).collect()
217}
218
219
220
221// =======================
222// === Type Dependency ===
223// =======================
224
225/// Naive type equality test by comparing its representation with a string.
226pub fn type_matches_repr(ty:&syn::Type, target_repr:&str) -> bool {
227    repr(ty) == target_repr
228}
229
230/// Naive type equality test by comparing their text representations.
231pub fn type_matches(ty:&syn::Type, target_param:&syn::GenericParam) -> bool {
232    type_matches_repr(ty, &repr(target_param))
233}
234
235/// Does type depends on the given type parameter.
236pub fn type_depends_on(ty:&syn::Type, target_param:&syn::GenericParam) -> bool {
237    let target_param = repr(target_param);
238    let relevant_types = gather_all_types(ty);
239    relevant_types.iter().any(|ty| repr(ty) == target_param)
240}
241
242/// Does enum variant depend on the given type parameter.
243pub fn variant_depends_on
244(var:&syn::Variant, target_param:&syn::GenericParam) -> bool {
245    var.fields.iter().any(|field| type_depends_on(&field.ty, target_param))
246}
247
248
249
250// ===================
251// === WhereClause ===
252// ===================
253
254/// Creates a new where clause from provided sequence of where predicates.
255pub fn new_where_clause(predicates:impl IntoIterator<Item=WherePredicate>) -> WhereClause {
256    let predicates = syn::punctuated::Punctuated::from_iter(predicates);
257    WhereClause {where_token:Default::default(),predicates}
258}
259
260
261
262// =============
263// === Tests ===
264// =============
265
#[cfg(test)]
mod tests {
    use super::*;
    use proc_macro2::TokenStream;

    /// Parses `code` as a `T`, panicking if it is not valid syntax for `T`.
    fn parse<T:syn::parse::Parse>(code:&str) -> T {
        syn::parse_str(code).unwrap()
    }

    #[test]
    fn repr_round_trips() {
        let program = "pub fn repr<T: quote::ToTokens>(t: &T) -> String {}";
        let tokens = parse::<TokenStream>(program);
        let quoted_program = repr(&tokens);
        let tokens2 = parse::<TokenStream>(&quoted_program);
        // check only second round-trip, first is allowed to break whitespace
        assert_eq!(repr(&tokens), repr(&tokens2));
    }

    #[test]
    fn fields_list_test() {
        let tuple_like     = "struct Unnamed(i32, String, T);";
        let proper_struct  = "struct Named{i: i32, s: String, t: T}";
        let expected_types = vec!["i32", "String", "T"];

        fn assert_field_types(program:&str, expected_types:&[&str]) {
            let tokens = parse::<syn::ItemStruct>(program);
            let fields = fields_list(&tokens.fields);
            let types  = fields.iter().map(|f| repr(&f.ty));
            assert_eq!(Vec::from_iter(types), expected_types);
        }

        assert_field_types(tuple_like, &expected_types);
        assert_field_types(proper_struct, &expected_types);
    }

    #[test]
    fn type_dependency() {
        let param:syn::GenericParam = parse("T");
        let depends                 = |code| {
            let ty:syn::Type = parse(code);
            type_depends_on(&ty, &param)
        };

        // sample types that depend on `T`
        let dependents = vec!{
            "T",
            "Option<T>",
            "Pair<T, U>",
            "Pair<U, T>",
            "Pair<U, (T,)>",
            "&T",
            "&'t mut T",
        };
        // sample types that do not depend on `T`; the first five mirror the dependent
        // samples above with `T` replaced by the similarly named `Tt`
        let independents = vec!{
            "Tt",
            "Option<Tt>",
            "Pair<Tt, U>",
            "Pair<U, Tt>",
            "Pair<U, (Tt,)>",
            "i32",
            "&str",
        };
        for dependent in dependents {
            assert!(depends(dependent), "{} must depend on {}"
                    , repr(&dependent), repr(&param));
        }
        for independent in independents {
            assert!(!depends(independent), "{} must not depend on {}"
                    , repr(&independent), repr(&param));
        }
    }

    #[test]
    fn collecting_type_path_args() {
        fn check(expected_type_args:Vec<&str>, ty_path:&str) {
            let ty_path = parse(ty_path);
            let args    = super::ty_path_type_args(&ty_path);
            assert_eq!(expected_type_args.len(), args.len());
            let zipped  = expected_type_args.iter().zip(args.iter());
            for (expected,got) in zipped {
                assert_eq!(expected, &repr(got));
            }
        }
        check(vec!["T"]     , "std::Option<T>");
        check(vec!["U"]     , "std::Option<U>");
        check(vec!["A", "B"], "Either<A,B>");
        assert_eq!(super::last_type_arg(&parse("i32")), None);
        assert_eq!(repr(&super::last_type_arg(&parse("Foo<C>"))), "C");
    }
}