double_dyn/
lib.rs

1
2#![crate_name = "double_dyn"]
3
4#![doc = include_str!("../README.md")]
5
6use std::collections::HashMap;
7
8use proc_macro2::token_stream::IntoIter as TokenIter;
9use proc_macro2::{*};
10use quote::{quote};
11use heck::AsSnakeCase;
12
13mod parse;
14use crate::parse::*;
15
16/// Emits traits and functions to enable multiple dynamic argument dispatch
17#[proc_macro]
18pub fn double_dyn(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
19
20    let output = match double_dyn_internal(input.into()) {
21        Ok(expanded) => expanded,
22        Err(error) => error.into_compile_error(),
23    };
24
25    output.into()
26}
27
28fn double_dyn_internal(input: TokenStream) -> Result<TokenStream, SyntaxError> {
29
30    //==================================================================================================================
31    // PHASE 1: Parse the Macro Invocation
32    //==================================================================================================================
33
34    //Parse the preamble of the invocation to get the trait names and any trait bounds
35    let mut iter = input.into_iter();
36    require_keyword(&mut iter, "type", Span::call_site())?;
37    require_keyword(&mut iter, "A", Span::call_site())?;
38    require_punct(&mut iter, ':', Span::call_site())?;
39    let trait_a_name = require_ident(&mut iter, Span::call_site())?;
40    let mut trait_a_bounds = TokenStream::new();
41    while !if_punct(&iter, ';')? {
42        let token = next_token(&mut iter, Span::call_site())?;
43        trait_a_bounds.extend([token]);
44    }
45    require_punct(&mut iter, ';', Span::call_site())?;
46    require_keyword(&mut iter, "type", Span::call_site())?;
47    require_keyword(&mut iter, "B", Span::call_site())?;
48    require_punct(&mut iter, ':', Span::call_site())?;
49    let trait_b_name = require_ident(&mut iter, Span::call_site())?;
50    let mut trait_b_bounds = TokenStream::new();
51    while !if_punct(&iter, ';')? {
52        let token = next_token(&mut iter, Span::call_site())?;
53        trait_b_bounds.extend([token]);
54    }
55    require_punct(&mut iter, ';', Span::call_site())?;
56
57    //See if both the A and B traits are the same, because that affects several behaviors later on
58    let single_trait = trait_a_name.to_string() == trait_b_name.to_string();
59
60    //The pub qualifiers must match across every function signature
61    let mut pub_qualifiers = TokenStream::new();
62
63    //Parse each function signature
64    let mut first_sig = true;
65    let mut fn_sigs = HashMap::new();
66    loop {
67        let mut temp_iter = iter.clone();
68        match require_fn_signature(&mut temp_iter, true, Span::call_site()) {
69            Ok(sig) => {
70
71                //Check that every arg has an arg name
72                for arg in sig.args.iter() {
73                    if arg.arg_name.is_none() {
74                        return Err(SyntaxError {
75                            message: format!("missing arg name.  anonymous args are not allowed"),
76                            span: arg.arg_type.clone().into_iter().next().unwrap().span(),
77                        });
78                    }
79                }
80
81                //Check for duplicate function signature names
82                if fn_sigs.get(&sig.fn_name.to_string()).is_some() {
83                    return Err(SyntaxError {
84                        message: format!("duplicate functions not allowed"),
85                        span: sig.fn_name.span(),
86                    });    
87                }
88
89                //Check that the pub qualifiers match across every function signature
90                if first_sig {
91                    pub_qualifiers = sig.pub_qualifiers.clone();
92                    first_sig = false;
93                } else {
94                    if tokens_to_string(pub_qualifiers.clone()) != tokens_to_string(sig.pub_qualifiers.clone()) {
95                        return Err(SyntaxError {
96                            message: format!("All functions must have the same visibility (e.g. 'pub')"),
97                            span: sig.fn_name.span(),
98                        });
99                    }
100                }
101
102                //Identify the arg indices that might be A or B
103                let mut possible_a_args = vec![];
104                let mut possible_b_args = vec![];
105                for (i, arg) in sig.args.iter().enumerate() {
106                    let arg_token_iter = arg.arg_type.clone().into_iter();
107                    if if_contains_sequence(&arg_token_iter, &["dyn", &trait_a_name.to_string()])? {
108                        possible_a_args.push(i);
109                    }
110                    if if_contains_sequence(&arg_token_iter, &["dyn", &trait_b_name.to_string()])? {
111                        possible_b_args.push(i);
112                    }
113                }
114
115                //If we didn't identify at least one potential A arg index and one potential B then it's an error
116                if possible_a_args.len() < 1 || possible_b_args.len() < 1 {
117                    return Err(SyntaxError {
118                        message: format!("function must have at least one dyn A and one dyn B argument"),
119                        span: sig.fn_name.span(),
120                    });    
121                }
122
123                //Add our valid sig to the map, and move on
124                fn_sigs.insert(sig.fn_name.to_string(), (sig, possible_a_args, possible_b_args));
125                iter = temp_iter;
126            },
127            Err(err) => {
128                if fn_sigs.len() > 0 {
129                    //See if we're ready to move on to implementations
130                    if if_keyword(&mut iter, "impl")? || if_punct(&mut iter, '#')? {
131                        //NOTE: currently we only have #attributes for impls.  This logic will need to change
132                        // if we end up needing to support attributes for functions
133                        break;
134                    } else {
135                        return Err(err); //We found some other error in the function signature
136                    }
137                } else {
138                    return Err(err); //We need at least one function signature
139                }
140            }
141        }
142    }
143
144    //Parse each type pair impl block
145    let mut pairs_map = HashMap::new();
146    let mut type_a_map = HashMap::new();
147    let mut type_b_map = HashMap::new();
148    loop {
149        let mut impl_fns = HashMap::new();
150
151        // Check for any attributes (specifically #[commutative])
152        let is_commutative = if if_punct(&mut iter, '#')? {
153            require_punct(&mut iter, '#', Span::call_site())?;
154            let attrib_group = require_group(&mut iter, Delimiter::Bracket, Span::call_site(), "expected square brackets")?;
155            let mut attrib_token_iter = attrib_group.stream().into_iter();
156
157            //Only the "commutative" attribute is supported
158            require_keyword(&mut attrib_token_iter, "commutative", attrib_group.span())?;
159
160            //"commutative" is only compative if the A and B traits are the same trait
161            if !single_trait {
162                return Err(SyntaxError {
163                    message: format!("commutative attribute requires matching A and B traits"),
164                    span: attrib_group.span(),
165                });
166            }
167
168            true
169        } else {
170            false
171        };
172
173        // The preamble, e.g. "impl for <TypeA, TypeB>"
174        require_keyword(&mut iter, "impl", Span::call_site())?;
175        require_keyword(&mut iter, "for", Span::call_site())?;
176        let type_pair_group = require_angle_group(&mut iter, Span::call_site(), "expected type pair in angle brackets")?;
177        let mut pair_token_iter = type_pair_group.interior_tokens.into_iter();
178
179        //We support either a type by itself or a list of types in square brackets
180        let type_a_list = require_type_or_type_list(&mut pair_token_iter, type_pair_group.close_bracket.span())?;
181        if !if_punct(&mut pair_token_iter, ',')? { //So the error message is a little better
182            return Err(SyntaxError {
183                message: format!("expected type or type list for 'B'"),
184                span: type_pair_group.close_bracket.span(),
185            });
186        }
187        require_punct(&mut pair_token_iter, ',', type_pair_group.close_bracket.span())?;
188        let type_b_list = require_type_or_type_list(&mut pair_token_iter, type_pair_group.close_bracket.span())?;
189
190        // The block containing the functions
191        let fn_group = require_group(&mut iter, Delimiter::Brace, Span::call_site(), "expected curly braces for fn impls")?;
192        let mut block_token_iter = fn_group.stream().into_iter();
193        while !if_end(&mut block_token_iter)? {
194            let sig = require_fn_signature(&mut block_token_iter, false, fn_group.span())?;
195            let fn_body = require_group(&mut block_token_iter, Delimiter::Brace, fn_group.span(), "expected fn body")?;
196            
197            //Check for duplicate function names
198            if impl_fns.get(&sig.fn_name.to_string()).is_some() {
199                return Err(SyntaxError {
200                    message: format!("duplicate functions not allowed"),
201                    span: sig.fn_name.span(),
202                });
203            }
204
205            //Check that this implementation name matches one of the signatures defined above
206            if let Some((template_sig, possible_a_args, possible_b_args)) = fn_sigs.get_mut(&sig.fn_name.to_string()) {
207
208                //Make sure the argument count matches the function template.  NOTE: You might think this check is unnecessary
209                // because we'd catch incompatible args later on, but we want to be able to rely on the argument list being the
210                // same length when manipulting the args array later on, before emitting the tokens to be compiled.
211                if template_sig.args.len() != sig.args.len() {
212                    return Err(SyntaxError {
213                        message: format!("argument count doesn't match signiture"),
214                        span: sig.fn_name.span(),
215                    });
216                }
217
218                //Make sure we can correlate the arg positions for the A and B types
219                for (i, arg) in sig.args.iter().enumerate() {
220                    let arg_token_iter = arg.arg_type.clone().into_iter();
221
222                    //We're looking for either an "#A" or the concrete A type itself in the case that we only have one possible A type
223                    if !if_contains_sequence(&arg_token_iter, &["#", "A"])? 
224                    && !if_contains_tokens(&arg_token_iter, type_a_list[0].clone().into_iter())? {
225                        //If this arg isn't a candidate for a type_a, make sure it's not in the possible_a_args list
226                        if let Some(idx) = possible_a_args.iter().position(|&el| el == i) {
227                            possible_a_args.remove(idx);
228                        }
229                    }
230                    //Do the same for B args
231                    if !if_contains_sequence(&arg_token_iter, &["#", "B"])? 
232                    && !if_contains_tokens(&arg_token_iter, type_b_list[0].clone().into_iter())? {
233                        //If this arg isn't a candidate for a type_a, make sure it's not in the possible_a_args list
234                        if let Some(idx) = possible_b_args.iter().position(|&el| el == i) {
235                            possible_b_args.remove(idx);
236                        }
237                    }
238                }
239
240                //If we ended up disqualifying every arg then that's a problem
241                if possible_a_args.len() < 1 {
242                    return Err(SyntaxError {
243                        message: format!("can't infer position of A arg when reconciled with fn signature"),
244                        span: sig.fn_name.span(),
245                    });
246                }
247                if possible_b_args.len() < 1 {
248                    return Err(SyntaxError {
249                        message: format!("can't infer position of B arg when reconciled with fn signature"),
250                        span: sig.fn_name.span(),
251                    });
252                }
253
254                impl_fns.insert(sig.fn_name.to_string(), (sig, fn_body));
255            } else {
256                return Err(SyntaxError {
257                    message: format!("matching fn signature not found"),
258                    span: sig.fn_name.span(),
259                });
260            }
261        }
262
263        //Check that every function has been implemented
264        if impl_fns.len() != fn_sigs.len() {
265            return Err(SyntaxError {
266                message: format!("incomplete implementation of declared functions"),
267                span: fn_group.span(),
268            });
269        }
270
271        //Put a pair record in the HashMap for each type_a-type_b pair
272        for type_a in type_a_list.iter() {
273            let type_a_string = format!("{}", AsSnakeCase(tokens_to_string(type_a.clone())));
274
275            for type_b in type_b_list.iter() {
276                let type_b_string = format!("{}", AsSnakeCase(tokens_to_string(type_b.clone())));
277
278                //Go over each fn implementation, and replace the placeholders with the concrete types
279                let mut updated_fns = HashMap::new();
280                for (fn_name, (sig, fn_body)) in impl_fns.iter() {
281
282                    //Go through the args in the function signature and swap out the #A and #B types
283                    let mut new_sig = sig.clone();
284                    for arg in new_sig.args.iter_mut() {
285                        let new_arg_type = replace_type_placeholders(arg.arg_type.clone(), type_a, type_b)?;
286                        arg.arg_type = new_arg_type;
287                    }
288
289                    //Now do the same thing for the function body
290                    let new_fn_body = replace_type_placeholders(fn_body.stream(), type_a, type_b)?;
291
292                    updated_fns.insert(fn_name.clone(), (new_sig, new_fn_body));
293                }
294
295                //Put the pair in the pairs_map
296                pairs_map
297                    .entry(type_a_string.clone())
298                    .and_modify(|type_b_map : &mut HashMap<String, HashMap<String, (FnSignature, TokenStream)>>| {
299                        type_b_map.insert(type_b_string.clone(), updated_fns.clone()); //NOTE: these clones bug me but the compiler probably takes care of them
300                    })
301                    .or_insert({
302                        let mut new_map = HashMap::with_capacity(1);
303                        new_map.insert(type_b_string.clone(), updated_fns);
304                        new_map
305                    });
306
307                //If the pair is_commutative, then put the inverse in the pairs_map as well
308                if is_commutative {
309
310                    //We need to do the #A and #B swap in reverse
311                    let mut updated_fns = HashMap::new();
312                    for (fn_name, (sig, fn_body)) in impl_fns.iter() {
313    
314                        //Go through the args in the function signature and swap out the #A and #B types
315                        let mut new_sig = sig.clone();
316                        for arg in new_sig.args.iter_mut() {
317                            let new_arg_type = replace_type_placeholders(arg.arg_type.clone(), type_b, type_a)?;
318                            arg.arg_type = new_arg_type;
319                        }
320    
321                        //Now do the same thing for the function body
322                        let new_fn_body = replace_type_placeholders(fn_body.stream(), type_b, type_a)?;
323    
324                        updated_fns.insert(fn_name.clone(), (new_sig, new_fn_body));
325                    }
326    
327                    //Put the inverse pair in the pairs_map
328                    pairs_map
329                        .entry(type_b_string.clone())
330                        .and_modify(|type_a_map : &mut HashMap<String, HashMap<String, (FnSignature, TokenStream)>>| {
331                            type_a_map.insert(type_a_string.clone(), updated_fns.clone());
332                        })
333                        .or_insert({
334                            let mut new_map = HashMap::with_capacity(1);
335                            new_map.insert(type_a_string.clone(), updated_fns);
336                            new_map
337                        });
338                }
339
340                //Update the map of all b_types
341                type_b_map.insert(type_b_string, type_b.clone());
342            }
343
344            //Update the map of all a_types
345            type_a_map.insert(type_a_string, type_a.clone());
346        }
347
348        //Any more tokens must be additional impl blocks
349        if if_end(&mut iter)? {
350            break;
351        }
352    }
353
354    //For each function, collapse the possible arg positions (possible_a_args & possible_b_args) into a single arg index
355    for (sig, possible_a_args, possible_b_args) in fn_sigs.values_mut() {
356
357        //If we have the same trait for A and B, then just pick one index for A and the other one for B
358        if single_trait {
359            if let Some(idx) = possible_b_args.iter().position(|&el| el == possible_a_args[0]) {
360                possible_b_args.remove(idx);
361            }
362            if let Some(idx) = possible_a_args.iter().position(|&el| el == possible_b_args[0]) {
363                possible_a_args.remove(idx);
364            }
365        }
366
367        //And if we ended up disqualifying all possible args, that's an error
368        if possible_a_args.len() < 1 || possible_b_args.len() < 1 {
369            return Err(SyntaxError {
370                message: format!("can't infer position of both A and B args"),
371                span: sig.fn_name.span(),
372            });
373        }
374
375        //Now if we have more than one index for for either A or B then the signature is ambiguous so that's an error
376        if possible_a_args.len() > 1 {
377            return Err(SyntaxError {
378                message: format!("ambiguous signature; can't infer position of A arg"),
379                span: sig.fn_name.span(),
380            });
381        }
382        if possible_b_args.len() > 1 {
383            return Err(SyntaxError {
384                message: format!("ambiguous signature; can't infer position of B arg"),
385                span: sig.fn_name.span(),
386            });
387        }
388    }
389    
390    //==================================================================================================================
391    // PHASE 2: Build the Output Tokens
392    //==================================================================================================================
393
394    //If we're only dealing with one trait, then we only have one list of types
395    if single_trait {
396        type_a_map.extend(type_b_map.iter().map(|pair| (pair.0.clone(), pair.1.clone())));
397        type_b_map = type_a_map.clone();
398    }
399    
400    //Transmute all of the function prototypes into methods for the ATrait
401    let mut l1_sig_tokens = TokenStream::new();
402    let mut l1_sigs = HashMap::new();
403    for (fn_name, (sig, possible_a_args, _possible_b_args)) in fn_sigs.iter() {
404
405        //Turns "fn min_max(val: i32, min: &dyn MyTraitA, max: &dyn MyTraitB) -> Result<i32, String>;" into
406        // "fn l1_min_max(&self, val: i32, max: &dyn MyTraitB) -> Result<i32, String>;"
407        let mut new_sig = sig.clone();
408        new_sig.pub_qualifiers = TokenStream::new(); //no visibility qualifiers on trait methods
409        new_sig.fn_name = Ident::new(&format!("l1_{}", fn_name), sig.fn_name.span());
410        new_sig.args.remove(possible_a_args[0]); //Get rid of the arg that'll be replaced by self
411        new_sig.args.insert(0, FnArg{
412            arg_name: None,
413            arg_type: quote! { &self }
414        });
415
416        let sig_tokens = render_fn_signature(new_sig.clone())?;
417        l1_sigs.insert(fn_name.clone(), (new_sig, sig_tokens.clone()));
418        l1_sig_tokens.extend(sig_tokens);
419        l1_sig_tokens.extend(quote! { ; });    
420    }
421
422    //Transmute all of the function prototypes into methods for the BTrait
423    let mut l2_sig_tokens = TokenStream::new();
424    let mut l2_sigs = HashMap::new();
425    for (fn_name, (sig, possible_a_args, possible_b_args)) in fn_sigs.iter() {
426
427        //Create a separate variant of each function for each of the A types
428        for a_type_string in type_a_map.keys() {
429
430            let (new_sig, _old_b_arg) = transmute_to_l2_signature(sig.clone(), a_type_string, &type_a_map, possible_a_args[0], possible_b_args[0])?;
431            let sig_tokens = render_fn_signature(new_sig.clone())?;
432            l2_sigs.insert((fn_name, a_type_string), (new_sig, sig_tokens.clone()));
433            l2_sig_tokens.extend(sig_tokens);
434            l2_sig_tokens.extend(quote! { ; });    
435        }
436    }
437
438    // --1-- Create the definition of the traits
439    let mut result_tokens = if single_trait {
440        quote! {
441            #pub_qualifiers trait #trait_a_name #trait_a_bounds {
442                #l1_sig_tokens
443
444                #l2_sig_tokens
445            }
446        }
447    } else {
448        quote! {
449            #pub_qualifiers trait #trait_a_name #trait_a_bounds {
450                #l1_sig_tokens
451            }
452
453            #pub_qualifiers trait #trait_b_name #trait_b_bounds {
454                #l2_sig_tokens
455            }
456        }
457    };
458
459    // --2-- Emit the L1 trait impls
460    for (a_type_name, a_type) in type_a_map.iter() {
461
462        //If we only have one trait, build up the l2 fn impls, to be included alongside the l1 fn impls
463        // Since the a-types and b-types are the same set of types, we need to provide impls for the whole set so
464        // we're passing the b_type for A and the a_type for B.
465        let l2_impls_single_trait = if single_trait {
466            let mut l2_impls = TokenStream::new();
467            for b_type_name in type_b_map.keys() {
468                let impl_tokens = render_l2_fns_for_pair(b_type_name, a_type_name, &pairs_map, &type_a_map, &fn_sigs, &l2_sigs)?;
469                l2_impls.extend(impl_tokens);
470            }
471            l2_impls
472        } else {
473            TokenStream::new()
474        };
475    
476        //Build up the tokens for the l1 methods, for the "impl TraitA for TypeA"
477        let mut l1_impls = TokenStream::new();
478        for (orig_fn_name, (_l1_sig, l1_sig_tokens)) in l1_sigs.iter() {
479            let (prototype_sig, possible_a_args, possible_b_args) = fn_sigs.get(orig_fn_name).unwrap();
480
481            //Get the name of the B arg, so we can use it to call the l2 function
482            let b_arg_name = prototype_sig.args[possible_b_args[0]].arg_name.clone().unwrap();
483
484            //We'll pass all of the other args to the l2 function
485            let mut other_arg_name_tokens = TokenStream::new();
486            for (i, arg) in prototype_sig.args.iter().enumerate() {
487                if i != possible_a_args[0] && i != possible_b_args[0] {
488                    let arg_name = arg.arg_name.clone().unwrap();
489                    other_arg_name_tokens.extend(quote! {
490                        #arg_name,
491                    });
492                }
493            }
494
495            //Figure out the l2 function name
496            let (l2_sig, _l2_sig_tokens) = l2_sigs.get(&(orig_fn_name, &a_type_name)).unwrap();
497            let l2_fn_name = &l2_sig.fn_name;
498
499            //Compose an l1 function that calls the appropriate l2 function with the right args
500            let l1_impl = quote! {
501                #l1_sig_tokens {
502                    #b_arg_name.#l2_fn_name(#other_arg_name_tokens &self)
503                }
504            };
505
506            l1_impls.extend(l1_impl);
507        }
508
509        let a_trait_impl = quote! {
510            impl #trait_a_name for #a_type {
511                #l1_impls
512                #l2_impls_single_trait
513            }
514        };
515
516        result_tokens.extend(a_trait_impl);
517    }
518
519    // --3-- Emit the L2 trait impls
520    if !single_trait {
521        for (b_type_name, b_type) in type_b_map.iter() {
522
523            let mut l2_impls = TokenStream::new();
524            for a_type_name in type_a_map.keys() {
525                let impl_tokens = render_l2_fns_for_pair(a_type_name, b_type_name, &pairs_map, &type_a_map, &fn_sigs, &l2_sigs)?;
526                l2_impls.extend(impl_tokens);
527            }
528
529            let b_trait_impl = quote! {
530                impl #trait_b_name for #b_type {
531                    #l2_impls
532                }
533            };
534    
535            result_tokens.extend(b_trait_impl);
536        }
537    }
538
539    // --4-- Emit the top-level function(s)
540    for (orig_fn_name, (sig, possible_a_args, _possible_b_args)) in fn_sigs.iter() {
541
542        let sig_tokens = render_fn_signature(sig.clone())?;
543        let (l1_sig, _l1_sig_tokens) = l1_sigs.get(orig_fn_name).unwrap();
544        let l1_fn_name = l1_sig.fn_name.clone();
545
546        //Get the name of the A arg, so we can use it to call the l1 trait method
547        let a_arg_name = sig.args[possible_a_args[0]].arg_name.clone().unwrap();
548
549        //We'll pass all of the other args to the l1 method
550        let mut other_arg_name_tokens = TokenStream::new();
551        for (i, arg) in sig.args.iter().enumerate() {
552            if i != possible_a_args[0] {
553                let arg_name = arg.arg_name.clone().unwrap();
554                other_arg_name_tokens.extend(quote! {
555                    #arg_name,
556                });
557            }
558        }
559
560        let fn_tokens = quote! {
561            #sig_tokens {
562                #a_arg_name.#l1_fn_name(#other_arg_name_tokens)
563            }
564        };
565
566        result_tokens.extend(fn_tokens);
567    }
568
569    Ok(result_tokens.into())
570}
571
572//Parse a type by itself or a list of types in square brackets
573fn require_type_or_type_list(iter: &mut TokenIter, err_span: Span) -> Result<Vec<TokenStream>, SyntaxError> {
574    
575    let mut type_list = vec![];
576    if if_group(iter, Delimiter::Bracket)? {
577        let type_list_group = require_group(iter, Delimiter::Bracket, err_span.clone(), "expected square braces for type array")?;
578        let mut type_tokens_iter = type_list_group.stream().into_iter();
579        loop {
580            type_list.push(require_type(&mut type_tokens_iter, type_list_group.span())?);
581            if if_end(&type_tokens_iter)? {
582                break;
583            } else {
584                require_punct(&mut type_tokens_iter, ',', type_list_group.span())?;
585            }
586        }
587        if type_list.len() < 1 {
588            //return err if we didn't push anything to the array
589            return Err(syntax(TokenTree::Group(type_list_group), "expected at least one type"));
590        }
591    } else {
592        let type_group = require_type(iter, err_span.clone())?;
593        type_list.push(type_group);
594    }
595
596    Ok(type_list)
597}
598
599fn render_l2_fns_for_pair(
600    a_type_name: &String,
601    b_type_name: &String,
602    pairs_map: &HashMap<String, HashMap<String, HashMap<String, (FnSignature, TokenStream)>>>,
603    type_a_map: &HashMap<String, TokenStream>,
604    fn_sigs: &HashMap<String, (FnSignature, Vec<usize>, Vec<usize>)>,
605    l2_sigs: &HashMap<(&String, &String), (FnSignature, TokenStream)>) -> Result<TokenStream, SyntaxError> {
606
607    let mut l2_impls = TokenStream::new();
608
609    let found_pair = if let Some(a_pair_map) = pairs_map.get(a_type_name) {
610        if let Some(pair_fn_map) = a_pair_map.get(b_type_name) {
611
612            //Emit methods with the body from the macro invocation 
613            for (orig_fn_name, (_sig, possible_a_args, possible_b_args)) in fn_sigs.iter() {
614
615                let (pair_fn_sig, pair_fn_body) = pair_fn_map.get(orig_fn_name).unwrap();
616                let (new_sig, old_b_arg) = transmute_to_l2_signature(pair_fn_sig.clone(), a_type_name, type_a_map, possible_a_args[0], possible_b_args[0])?;
617                let sig_tokens = render_fn_signature(new_sig)?;
618                l2_impls.extend(sig_tokens);
619
620                //Emit an assignment, to assign self back to the original argument name
621                let old_b_arg_name = old_b_arg.arg_name.clone().unwrap();
622                let self_assignment_tokens = quote! {
623                    let #old_b_arg_name = self;
624                };
625
626                l2_impls.extend(quote! {
627                    {
628                        #self_assignment_tokens
629
630                        #pair_fn_body
631                    }
632                });
633            }
634
635            true
636        } else {
637            false
638        }
639    } else {
640        false
641    };
642
643    if !found_pair {
644        //Emit methods with an "unimplemented" body
645        for orig_fn_name in fn_sigs.keys() {
646
647            //Get the tokens for the l2 fn signature from the l2_sigs HashMap, and prepend a '_' to the arg names
648            // in order to supress "unused variable" warnings
649            let (l2_sig, _l2_sig_tokens) = l2_sigs.get(&(orig_fn_name, a_type_name)).unwrap();
650            let mut new_sig = l2_sig.clone();
651            for arg in new_sig.args.iter_mut() {
652                if let Some(arg_name) = &mut arg.arg_name {
653                    *arg_name = Ident::new(&format!("_{}", arg_name.to_string()), arg_name.span());
654                }
655            }
656            let new_sig_tokens = render_fn_signature(new_sig)?;
657
658            l2_impls.extend(new_sig_tokens);
659            l2_impls.extend(quote! {
660                {
661                    unimplemented!();
662                }
663            });
664        }
665    }
666
667    Ok(l2_impls)
668}
669
670//Turns "fn min_max(val: i32, min: &dyn MyTraitA, max: &dyn MyTraitB) -> Result<i32, String>;" into
671// "fn l2_min_max_i32(&self, val: i32, min: &i32) -> Result<i32, String>;"
672fn transmute_to_l2_signature(original_sig: FnSignature, a_type_string: &String, type_a_map: &HashMap<String, TokenStream>, a_arg_idx: usize, b_arg_idx: usize) -> Result<(FnSignature, FnArg), SyntaxError> {
673
674    let new_fn_name = Ident::new(&format!("l2_{}_{}", original_sig.fn_name.to_string(), a_type_string), original_sig.fn_name.span());
675    let mut new_sig = original_sig;
676    new_sig.pub_qualifiers = TokenStream::new(); //no visibility qualifiers on trait methods
677    new_sig.fn_name = new_fn_name;
678    //Remove the A and B args because we'll replace them.  But we need to remove them in the right order
679    // because we don't want to screw up the indices
680    let (old_a_arg, old_b_arg) = if a_arg_idx < b_arg_idx {
681        let old_b_arg = new_sig.args.remove(b_arg_idx);
682        let old_a_arg = new_sig.args.remove(a_arg_idx);
683        (old_a_arg, old_b_arg)
684    } else {
685        let old_a_arg = new_sig.args.remove(a_arg_idx);
686        let old_b_arg = new_sig.args.remove(b_arg_idx);
687        (old_a_arg, old_b_arg)
688    };
689    new_sig.args.insert(0, FnArg{
690        arg_name: None,
691        arg_type: quote! { &self }
692    });
693    let type_a_tokens = type_a_map.get(a_type_string).unwrap().clone();
694    new_sig.args.push(FnArg{
695        arg_name: old_a_arg.arg_name,
696        arg_type: quote! { &#type_a_tokens }
697    });
698
699    Ok((new_sig, old_b_arg))
700}
701
702//Replaces "#A" and "#B" placeholders with the tokens representing concrete types
703fn replace_type_placeholders(input_stream: TokenStream, type_a: &TokenStream, type_b: &TokenStream) -> Result<TokenStream, SyntaxError> {
704
705    let mut fn_body_iter = input_stream.into_iter();
706    let mut previous_hash = false;
707    recursive_scan(&mut fn_body_iter, &mut |token, stream| {
708
709        if previous_hash {
710            if let TokenTree::Ident(ident) = token {
711                match ident.to_string().as_str() {
712                    "A" => {
713                        stream.extend([type_a.clone()]);
714                    },
715                    "B" => {
716                        stream.extend([type_b.clone()]);
717                    },
718                    _ => return Err(format!("unknown type macro identifier, #{}", ident.to_string())),
719                };
720                previous_hash = false;
721                return Ok(());
722            } else {
723                return Err(format!("expected special type macro identifier"));
724            }
725        }
726
727        if let TokenTree::Punct(punct) = &token {
728            if punct.as_char() == '#' {
729                previous_hash = true;
730                return Ok(());
731            }
732        }
733
734        stream.extend([token]);
735        Ok(())
736    })
737}
738
739fn tokens_to_string(tokens: TokenStream) -> String {
740    let mut out_string = "".to_string();
741    for token in tokens.into_iter() {
742        match token {
743            TokenTree::Ident(ident) => {
744                out_string.push_str(&ident.to_string());
745            }
746            TokenTree::Literal(literal) => {
747                out_string.push_str(&literal.to_string());
748            }
749            TokenTree::Punct(punct) => {
750                let punct_str = match punct.as_char() {
751                    '&' => "_amp_",
752                    '*' => "_star_",
753                    '.' => "_dot_",
754                    ',' => "_comma_",
755                    '#' => "_hash_",
756                    '@' => "_at_",
757                    '!' => "_bang_",
758                    '$' => "_dollar_",
759                    '%' => "_pct_",
760                    '^' => "_caret_",
761                    '<' => "_lt_",
762                    '>' => "_gt_",
763                    _ => "_punct_"
764                };
765                out_string.push_str(punct_str);
766            }
767            TokenTree::Group(group) => {
768
769                let (open_delim, close_delim) = match group.delimiter() {
770                    Delimiter::Brace => ("_open_curly_", "_close_curly_"),
771                    Delimiter::Parenthesis => ("_open_paren_", "_close_paren_"),
772                    Delimiter::Bracket => ("_open_square_", "_close_square_"),
773                    Delimiter::None => ("_open_none_", "_close_none_"),
774                };
775                let insides = tokens_to_string(group.stream());
776
777                out_string.push_str(open_delim);
778                out_string.push_str(&insides);
779                out_string.push_str(close_delim);
780            }
781
782        }
783    }
784    out_string
785}
786
787fn render_fn_signature(sig: FnSignature) -> Result<TokenStream, SyntaxError> {
788
789    let fn_name = sig.fn_name;
790
791    let generic_tokens = if !sig.generics.is_empty() {
792        let sig_generics = sig.generics;
793        quote! {
794            < #sig_generics >
795        }
796    } else {
797        TokenStream::new()
798    };
799
800    let mut arg_list_tokens = TokenStream::new();
801    for arg in sig.args {
802        if let Some(arg_name_ident) = arg.arg_name {
803            arg_list_tokens.extend([TokenTree::Ident(arg_name_ident), TokenTree::Punct(Punct::new(':', Spacing::Alone))]);
804        }
805
806        arg_list_tokens.extend(arg.arg_type);
807        arg_list_tokens.extend([TokenTree::Punct(Punct::new(',', Spacing::Alone))]);
808    }
809    let result_tokens = if !sig.result.is_empty() {
810        let sig_results = sig.result;
811        quote! {
812            -> #sig_results
813        }
814    } else {
815        TokenStream::new()
816    };
817
818    let pub_qualifiers = sig.pub_qualifiers;
819
820    let sig_tokens = quote! {
821        #pub_qualifiers fn #fn_name #generic_tokens (#arg_list_tokens) #result_tokens
822    };
823
824    Ok(sig_tokens)
825}
826
827
828// // // // // // // // // // // // // // // // // // // // // // // // // // // // // // // // // // //
829// Reference (example of input and the corresponding output, in a form that's easier to read)
830// // // // // // // // // // // // // // // // // // // // // // // // // // // // // // // // // // //
831
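//No hand-written example is kept here any more: run `cargo expand` on a crate that invokes `double_dyn!` to see the code generated for a given input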
833
834//=====================================================================================
835// Unit Tests
836//=====================================================================================
837
//require_fn_arg must accept an optional `name:` followed by a type.  The accepted cases include
// plain types, references, `dyn` trait objects and nested generics.
//It must reject empty input, a name with no type, and a type whose '<' is never closed.
#[test]
fn require_fn_arg_test() {

    let accepted = vec![
        quote! { i32 },
        quote! { val: i32 },
        quote! { a: &dyn PrimInt },
        quote! { a: &i32 },
        quote! { a: &Vec<&i32> },
        quote! { a: Box<dyn PrimInt> },
        quote! { a: HashMap<String, Box<dyn PrimInt>> },
    ];
    for arg_tokens in accepted {
        let mut arg_iter = arg_tokens.into_iter();
        assert!(require_fn_arg(&mut arg_iter, Span::call_site()).is_ok());
    }

    let rejected = vec![
        //No tokens at all
        quote! {},
        //A name with no type after it
        quote! { val: },
        //Unclosed '<'.  The compiler tokenizes this fine, though it can confuse editor highlighting
        quote! { HashMap<String },
    ];
    for arg_tokens in rejected {
        let mut arg_iter = arg_tokens.into_iter();
        assert!(require_fn_arg(&mut arg_iter, Span::call_site()).is_err());
    }
}
908
#[test]
fn require_fn_signature_test() {

    use quote::quote;
    use crate::parse::require_fn_signature;

    //Runs require_fn_signature over one invocation's tokens, panicking if they are rejected
    let parse_sig = |tokens: TokenStream| {
        let mut tokens_iter = tokens.into_iter();
        require_fn_signature(&mut tokens_iter, true, Span::call_site()).unwrap()
    };

    //=====================================================================================
    //A basic signature: the fn name is captured
    let sig = parse_sig(quote! {
        fn min_max(val: i32, min: &i32, max: &i32);
    });
    assert_eq!(sig.fn_name, "min_max");

    //=====================================================================================
    //`pub`: captured as the one and only visibility token
    let sig = parse_sig(quote! {
        pub fn min_max(val: i32, min: &i32, max: &i32) -> Result<i32, String>;
    });
    let mut vis_iter = sig.pub_qualifiers.into_iter();
    require_keyword(&mut vis_iter, "pub", Span::call_site()).unwrap();
    assert!(vis_iter.next().is_none());

    //=====================================================================================
    //`pub(crate)`: captured as `pub` followed by a parenthesized group, and nothing else
    let sig = parse_sig(quote! {
        pub(crate) fn min_max(val: i32, min: &i32, max: &i32) -> Result<i32, String>;
    });
    let mut vis_iter = sig.pub_qualifiers.into_iter();
    require_keyword(&mut vis_iter, "pub", Span::call_site()).unwrap();
    require_group(&mut vis_iter, Delimiter::Parenthesis, Span::call_site(), "missing '(crate)'").unwrap();
    assert!(vis_iter.next().is_none());

    //=====================================================================================
    //Simple generics: `<A, B>` is captured as `A , B`, without the outer angle brackets
    let sig = parse_sig(quote! {
        fn min_max<A, B>(val: i32, min: &A, max: &B) -> Result<A, String>;
    });
    let mut generics_iter = sig.generics.into_iter();
    let _ = require_ident(&mut generics_iter, Span::call_site()).unwrap();
    let _ = require_punct(&mut generics_iter, ',', Span::call_site()).unwrap();
    let _ = require_ident(&mut generics_iter, Span::call_site()).unwrap();
    assert!(generics_iter.next().is_none());

    //=====================================================================================
    //Nested generics: the bound's own `<i32>` is kept intact inside the captured generics
    let sig = parse_sig(quote! {
        fn min_max<A:From<i32>, B>(val: i32, min: &A, max: &B) -> Result<A, String>;
    });
    let mut generics_iter = sig.generics.into_iter();
    let _ = require_ident(&mut generics_iter, Span::call_site()).unwrap();
    let _ = require_punct(&mut generics_iter, ':', Span::call_site()).unwrap();
    let _ = require_ident(&mut generics_iter, Span::call_site()).unwrap();
    let _ = require_angle_group(&mut generics_iter, Span::call_site(), "expecting angle brackets").unwrap();
    let _ = require_punct(&mut generics_iter, ',', Span::call_site()).unwrap();
    let _ = require_ident(&mut generics_iter, Span::call_site()).unwrap();
    assert!(generics_iter.next().is_none());

    //=====================================================================================
    //Named args: every arg is captured along with its name
    let sig = parse_sig(quote! {
        fn min_max(val: i32, min: &i32, max: &i32);
    });
    assert_eq!(sig.args.len(), 3);
    assert!(sig.args[0].arg_name.is_some());
    let mut last_arg_type_iter = sig.args[2].arg_type.clone().into_iter();
    let _ = require_punct(&mut last_arg_type_iter, '&', Span::call_site()).unwrap();
    let _ = require_ident(&mut last_arg_type_iter, Span::call_site()).unwrap();

    //=====================================================================================
    //Unnamed args: every arg is captured, with no name
    let sig = parse_sig(quote! {
        fn min_max(i32, &i32, &i32);
    });
    assert_eq!(sig.args.len(), 3);
    assert!(sig.args[0].arg_name.is_none());
    let mut last_arg_type_iter = sig.args[2].arg_type.clone().into_iter();
    let _ = require_punct(&mut last_arg_type_iter, '&', Span::call_site()).unwrap();
    let _ = require_ident(&mut last_arg_type_iter, Span::call_site()).unwrap();

    //=====================================================================================
    //An empty argument list is accepted
    parse_sig(quote! {
        fn min_max();
    });

    //=====================================================================================
    //Return type: captured without the leading `->`
    let sig = parse_sig(quote! {
        fn min_max<A>() -> Result<A, String>;
    });
    let mut result_iter = sig.result.into_iter();
    let _ = require_ident(&mut result_iter, Span::call_site()).unwrap();
    let _ = require_angle_group(&mut result_iter, Span::call_site(), "expecting angle brackets").unwrap();
}