rem_controller/
non_local_controller.rs

1use convert_case::{
2    Case,
3    Casing
4};
5use log::debug;
6use proc_macro2::{
7    Ident,
8    Span
9};
10use quote::{
11    quote,
12    ToTokens
13};
14use rem_utils::{
15    format_source,
16    FindCallee
17};
18use syn::visit_mut::VisitMut;
19use syn::{
20    File,
21    parse_str,
22    Block,
23    Expr,
24    ExprCall,
25    ExprMatch,
26    ExprMethodCall,
27    ExprReturn, ExprTry,
28    ImplItemMethod,
29    // Item,
30    ItemFn,
31    // ItemImpl,
32    ItemMod,
33    // ItemTrait,
34    ReturnType,
35    Signature,
36    Stmt,
37    TraitItemMethod,
38    Type
39};
40// use syn::token::Brace;
41
42use crate::error::ControllerError;
43
/// Prefix of the generated control-flow enum's name; the Pascal-cased callee name is
/// appended to it wherever the enum identifier is built in this file.
const ENUM_NAME: &str = "Ret";
45
/// Input for the controller: the source text and the names of the two functions involved.
///
/// `Eq` is derived together with `PartialEq` and `Hash` so that values can be used as
/// `HashSet`/`HashMap` keys (hash-based collections require `Eq`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ControllerInput {
    /// The original input code.
    pub input_code: String,
    /// Name used to locate the calling function.
    pub caller_fn_name: String,
    /// Name used to locate the called function.
    pub callee_fn_name: String,
}
52
53fn make_pascal_case(s: &str) -> String {
54    let result = s.to_case(Case::Pascal);
55    match result.strip_suffix("ExtractThis") {
56        Some(r) => r.to_string(),
57        None => result.to_string(),
58    }
59}
60
61struct CheckCalleeWithinLoopHelper<'a> {
62    callee_fn_name: &'a str,
63    callee_in_loop: bool,
64}
65
66impl VisitMut for CheckCalleeWithinLoopHelper<'_> {
67    fn visit_expr_call_mut(&mut self, i: &mut ExprCall) {
68        let id = i.func.as_ref().into_token_stream().to_string();
69        match id.contains(self.callee_fn_name) {
70            true => self.callee_in_loop = true,
71            false => syn::visit_mut::visit_expr_call_mut(self, i),
72        }
73    }
74
75    fn visit_expr_method_call_mut(&mut self, i: &mut ExprMethodCall) {
76        let callee = i.clone().method.into_token_stream().to_string();
77        match callee.contains(self.callee_fn_name) {
78            true => self.callee_in_loop = true,
79            false => syn::visit_mut::visit_expr_method_call_mut(self, i),
80        }
81    }
82}
83
84struct CheckCalleeWithinLoop<'a> {
85    callee_fn_name: &'a str,
86    callee_in_loop: bool,
87}
88
89impl VisitMut for CheckCalleeWithinLoop<'_> {
90    fn visit_expr_mut(&mut self, i: &mut Expr) {
91        let mut helper = CheckCalleeWithinLoopHelper {
92            callee_fn_name: self.callee_fn_name,
93            callee_in_loop: self.callee_in_loop,
94        };
95        match i {
96            Expr::ForLoop(l) => {
97                helper.visit_expr_for_loop_mut(l);
98                if helper.callee_in_loop {
99                    self.callee_in_loop = true
100                };
101            }
102            Expr::Loop(l) => {
103                helper.visit_expr_loop_mut(l);
104                if helper.callee_in_loop {
105                    self.callee_in_loop = true
106                };
107            }
108            Expr::While(l) => {
109                helper.visit_expr_while_mut(l);
110                if helper.callee_in_loop {
111                    self.callee_in_loop = true
112                };
113            }
114
115            _ => syn::visit_mut::visit_expr_mut(self, i),
116        }
117    }
118}
119
120struct CallerVisitor<'a> {
121    found: bool,
122    caller_fn_name: &'a str,
123    callee_finder: &'a mut FindCallee<'a>,
124    callee_fn_name: &'a str,
125    callee_in_loop: bool,
126    // very simplified handling: if caller has loop and callee has break/continue but no loop
127    // assume it's control flow for caller otherwise, keep the same (assume control for callee loop)
128    caller_rety: &'a mut ReturnType,
129}
130
131impl VisitMut for CallerVisitor<'_> {
132    fn visit_impl_item_method_mut(&mut self, i: &mut ImplItemMethod) {
133        if self.callee_finder.found {
134            return;
135        }
136        debug!("finding caller in impl...");
137        let id = i.sig.ident.clone().to_string();
138        match id.contains(self.caller_fn_name) {
139            false => (),
140            true => {
141                debug!("found same id: {}...", id);
142                self.callee_finder.visit_impl_item_method_mut(i);
143                debug!(
144                    "found callee: {}? {}...",
145                    self.callee_finder.callee_fn_name, self.callee_finder.found
146                );
147                if !self.callee_finder.found {
148                    return;
149                }
150                self.caller_visitor(&mut i.sig, &mut i.block)
151            }
152        }
153        syn::visit_mut::visit_impl_item_method_mut(self, i);
154    }
155
156    fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
157        if self.callee_finder.found {
158            return;
159        }
160
161        let id = i.sig.ident.clone().to_string();
162        match id.contains(self.caller_fn_name) {
163            false => (),
164            true => {
165                self.callee_finder.visit_item_fn_mut(i);
166                if !self.callee_finder.found {
167                    return;
168                }
169                self.caller_visitor(&mut i.sig, &mut i.block)
170            }
171        }
172    }
173
174    fn visit_trait_item_method_mut(&mut self, i: &mut TraitItemMethod) {
175        if self.callee_finder.found {
176            return;
177        }
178
179        let id = i.sig.ident.clone().to_string();
180        match id.contains(self.caller_fn_name) {
181            false => (),
182            true => {
183                self.callee_finder.visit_trait_item_method_mut(i);
184                if !self.callee_finder.found {
185                    return;
186                }
187                let _ = i
188                    .default
189                    .as_mut()
190                    .and_then(|block| Some(self.caller_visitor(&mut i.sig, block)));
191            }
192        }
193        syn::visit_mut::visit_trait_item_method_mut(self, i);
194    }
195}
196
197impl CallerVisitor<'_> {
198    fn caller_visitor(&mut self, sig: &mut Signature, block: &mut Block) {
199        self.found = true;
200        *self.caller_rety = sig.output.clone();
201        let mut helper = CheckCalleeWithinLoop {
202            callee_fn_name: self.callee_fn_name,
203            callee_in_loop: false,
204        };
205        helper.visit_block_mut(block);
206        self.callee_in_loop = helper.callee_in_loop;
207    }
208}
209
210enum RetTyQMark {
211    QMarkOption,
212    QMarkResult,
213}
214
215struct CalleeDeSugarQMark {
216    has_desugared: bool,
217    rety_qmark: RetTyQMark,
218}
219
220impl VisitMut for CalleeDeSugarQMark {
221    fn visit_expr_mut(&mut self, i: &mut Expr) {
222        match i {
223            Expr::Try(ExprTry { expr, .. }) => {
224                let inner = expr.as_mut().clone();
225                match self.rety_qmark {
226                    RetTyQMark::QMarkOption => {
227                        *i = syn::parse_str(
228                            format!(
229                                "match {} {{ Some(x) => x, None => return None }}",
230                                inner.into_token_stream().to_string()
231                            )
232                            .as_str(),
233                        )
234                        .unwrap();
235                    }
236                    RetTyQMark::QMarkResult => {
237                        *i = syn::parse_str(
238                            format!(
239                                "match {} {{ Ok(x) => x, Err(e) => return Err(e) }}",
240                                inner.into_token_stream().to_string()
241                            )
242                            .as_str(),
243                        )
244                        .unwrap();
245                    }
246                }
247                self.has_desugared = true;
248            }
249            _ => (),
250        }
251        syn::visit_mut::visit_expr_mut(self, i);
252    }
253}
254
255struct CalleeCheckReturn {
256    has_return: bool,
257}
258
259impl VisitMut for CalleeCheckReturn {
260    fn visit_expr_return_mut(&mut self, _e: &mut ExprReturn) {
261        debug!("has return?{:?}", _e);
262        self.has_return = true
263    }
264}
265
266struct CalleeCheckLoops {
267    has_break: bool,
268    has_continue: bool,
269}
270
271impl VisitMut for CalleeCheckLoops {
272    fn visit_expr_mut(&mut self, i: &mut Expr) {
273        match i {
274            Expr::Break(_) => self.has_break = true,
275            Expr::Continue(_) => self.has_continue = true,
276
277            // don't check for loop control within callee loops
278            Expr::ForLoop(_) => (),
279            Expr::Loop(_) => (),
280            Expr::While(_) => (),
281
282            _ => syn::visit_mut::visit_expr_mut(self, i),
283        }
284    }
285}
286
287#[derive(Debug)]
288struct CalleeCheckNCF<'a> {
289    found: bool,
290    callee_fn_name: &'a str,
291    caller_rety: ReturnType,
292    within_caller_loop: bool,
293    has_break: bool,
294    has_continue: bool,
295    has_return: bool,
296    num_inputs: usize,
297}
298
299impl VisitMut for CalleeCheckNCF<'_> {
300    fn visit_impl_item_method_mut(&mut self, i: &mut ImplItemMethod) {
301        let id = i.sig.ident.to_string();
302        match id.contains(self.callee_fn_name) {
303            false => (),
304            true => self.callee_check_ncf(i.sig.clone(), &mut i.block),
305        }
306        syn::visit_mut::visit_impl_item_method_mut(self, i);
307    }
308
309    fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
310        let id = i.sig.ident.to_string();
311        match id.contains(self.callee_fn_name) {
312            false => (),
313            true => self.callee_check_ncf(i.sig.clone(), &mut i.block),
314        }
315    }
316
317    fn visit_trait_item_method_mut(&mut self, i: &mut TraitItemMethod) {
318        let id = i.sig.ident.to_string();
319        match id.contains(self.callee_fn_name) {
320            false => (),
321            true => {
322                let _ = i
323                    .default
324                    .as_mut()
325                    .and_then(|block| Some(self.callee_check_ncf(i.sig.clone(), block)));
326            }
327        }
328        syn::visit_mut::visit_trait_item_method_mut(self, i);
329    }
330}
331
332impl CalleeCheckNCF<'_> {
333    fn callee_check_ncf(&mut self, sig: Signature, block: &mut Block) {
334        self.found = true;
335
336        match &self.caller_rety {
337            ReturnType::Default => {}
338            ReturnType::Type(_, ty) => {
339                let mut rety = None;
340                if ty
341                    .as_ref()
342                    .clone()
343                    .into_token_stream()
344                    .to_string()
345                    .starts_with("Result")
346                {
347                    rety = Some(RetTyQMark::QMarkResult)
348                } else if ty
349                    .as_ref()
350                    .clone()
351                    .into_token_stream()
352                    .to_string()
353                    .starts_with("Option")
354                {
355                    rety = Some(RetTyQMark::QMarkOption)
356                }
357
358                match rety {
359                    None => (),
360                    Some(rety_qmark) => {
361                        debug!("desugaring...");
362                        let mut desugar_qmark = CalleeDeSugarQMark {
363                            has_desugared: false,
364                            rety_qmark,
365                        };
366                        desugar_qmark.visit_block_mut(block);
367                        debug!("desugaring...{}", desugar_qmark.has_desugared);
368                        self.has_return = desugar_qmark.has_desugared || self.has_return;
369                    }
370                }
371            }
372        }
373
374        self.num_inputs = sig.inputs.len();
375        let mut check_return = CalleeCheckReturn {
376            has_return: self.has_return,
377        };
378
379        let mut check_loops = CalleeCheckLoops {
380            has_break: self.has_break,
381            has_continue: self.has_continue,
382        };
383        block.stmts.iter_mut().for_each(|stmt| {
384            if !self.has_return {
385                check_return.visit_stmt_mut(stmt);
386            }
387            if self.within_caller_loop {
388                check_loops.visit_stmt_mut(stmt);
389            }
390        });
391        self.has_return = check_return.has_return;
392        self.has_break = check_loops.has_break;
393        self.has_continue = check_loops.has_continue;
394    }
395}
396
397struct MakeLastReturnBlkVisitor {}
398
399impl VisitMut for MakeLastReturnBlkVisitor {
400    fn visit_stmt_mut(&mut self, i: &mut Stmt) {
401        match i {
402            Stmt::Expr(e) => {
403                let re = quote!(result);
404                let e = e.clone();
405                *i = syn::parse_quote! {let #re = #e;}
406            }
407            _ => syn::visit_mut::visit_stmt_mut(self, i),
408        }
409    }
410}
411
412struct MakeBrkAndContVisitor<'a> {
413    callee_fn_name: &'a str,
414    success: bool,
415}
416
417impl VisitMut for MakeBrkAndContVisitor<'_> {
418    fn visit_expr_mut(&mut self, i: &mut Expr) {
419        // println!(
420        //     "expr make brk: {}",
421        //     i.clone().into_token_stream().to_string()
422        // );
423        match i {
424            Expr::Break(e) => {
425                match &e.expr {
426                    None => {}
427                    Some(_) => self.success = false,
428                }
429                let new_e_str = format!(
430                    "return {}{}::Break",
431                    ENUM_NAME,
432                    make_pascal_case(self.callee_fn_name)
433                );
434                let new_e: Expr = syn::parse_str(new_e_str.as_str()).unwrap();
435                *i = new_e
436            }
437            Expr::Continue(_) => {
438                let new_e_str = format!(
439                    "return {}{}::Continue",
440                    ENUM_NAME,
441                    make_pascal_case(self.callee_fn_name)
442                );
443                let new_e: Expr = syn::parse_str(new_e_str.as_str()).unwrap();
444                *i = new_e
445            }
446            _ => syn::visit_mut::visit_expr_mut(self, i),
447        }
448    }
449}
450
451struct MakeBrkAndCont<'a> {
452    callee_fn_name: &'a str,
453    success: bool,
454    already_did_return: bool,
455}
456
457impl VisitMut for MakeBrkAndCont<'_> {
458    fn visit_impl_item_method_mut(&mut self, i: &mut ImplItemMethod) {
459        let id = i.sig.ident.to_string();
460        match id.contains(self.callee_fn_name) {
461            false => (),
462            true => self.make_brk_and_cont(&mut i.sig, &mut i.block),
463        }
464        syn::visit_mut::visit_impl_item_method_mut(self, i);
465    }
466
467    fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
468        let id = i.sig.ident.to_string();
469        match id.contains(self.callee_fn_name) {
470            false => (),
471            true => self.make_brk_and_cont(&mut i.sig, &mut i.block),
472        }
473    }
474
475    fn visit_trait_item_method_mut(&mut self, i: &mut TraitItemMethod) {
476        let id = i.sig.ident.to_string();
477        //println!("caller name: {}, at: {}", self.caller_fn_name, &id);
478        match id.contains(self.callee_fn_name) {
479            false => (),
480            true => {
481                let _ = i
482                    .default
483                    .as_mut()
484                    .and_then(|block| Some(self.make_brk_and_cont(&mut i.sig, block)));
485            }
486        }
487        syn::visit_mut::visit_trait_item_method_mut(self, i);
488    }
489}
490
491impl MakeBrkAndCont<'_> {
492    fn make_brk_and_cont(&mut self, sig: &mut Signature, block: &mut Block) {
493        let mut helper = MakeBrkAndContVisitor {
494            callee_fn_name: self.callee_fn_name,
495            success: self.success,
496        };
497        helper.visit_block_mut(block);
498        self.success = helper.success;
499        if !self.already_did_return {
500            let ident_str = format!("{}{}", ENUM_NAME, make_pascal_case(self.callee_fn_name));
501            let ident = Ident::new(ident_str.as_str(), Span::call_site());
502            let callee_rety = match sig.output.clone() {
503                ReturnType::Default => Type::Verbatim(quote! {()}),
504                ReturnType::Type(_, t) => t.as_ref().clone(),
505            };
506            let ty: Type = Type::Verbatim(quote! {#ident<#callee_rety>});
507            sig.output = ReturnType::Type(syn::parse_quote! {->}, Box::new(ty));
508
509            let ok = quote!(Ok);
510            match block.stmts.last_mut() {
511                None => {}
512                Some(s) => match s {
513                    Stmt::Expr(_) => {
514                        let mut helper = MakeLastReturnBlkVisitor {};
515                        helper.visit_stmt_mut(s);
516                        let re = quote!(result);
517                        let ret_stmt_expr: Expr = syn::parse_quote! {#ident::#ok(#re)};
518                        block.stmts.push(Stmt::Expr(ret_stmt_expr))
519                    }
520                    _ => {
521                        let ret_stmt_expr: Expr = syn::parse_quote! {#ident::#ok(())};
522                        block.stmts.push(Stmt::Expr(ret_stmt_expr))
523                    }
524                },
525            }
526        }
527    }
528}
529
530struct MakeReturn<'a> {
531    callee_fn_name: &'a str,
532    caller_rety: &'a Type,
533}
534
535impl VisitMut for MakeReturn<'_> {
536    fn visit_impl_item_method_mut(&mut self, i: &mut ImplItemMethod) {
537        let id = i.sig.ident.to_string();
538        match id.contains(self.callee_fn_name) {
539            false => (),
540            true => self.make_return(&mut i.sig, &mut i.block),
541        }
542        syn::visit_mut::visit_impl_item_method_mut(self, i);
543    }
544
545    fn visit_trait_item_method_mut(&mut self, i: &mut TraitItemMethod) {
546        let id = i.sig.ident.to_string();
547        //println!("caller name: {}, at: {}", self.caller_fn_name, &id);
548        match id.contains(self.callee_fn_name) {
549            false => (),
550            true => {
551                let _ = i
552                    .default
553                    .as_mut()
554                    .and_then(|block| Some(self.make_return(&mut i.sig, block)));
555            }
556        }
557        syn::visit_mut::visit_trait_item_method_mut(self, i);
558    }
559
560    fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
561        let id = i.sig.ident.to_string();
562        match id.contains(self.callee_fn_name) {
563            false => (),
564            true => self.make_return(&mut i.sig, &mut i.block),
565        }
566    }
567}
568
569impl MakeReturn<'_> {
570    fn make_return(&mut self, sig: &mut Signature, block: &mut Block) {
571        let ident_str = format!("{}{}", ENUM_NAME, make_pascal_case(self.callee_fn_name));
572        let ident = Ident::new(ident_str.as_str(), Span::call_site());
573        let caller_rety = self.caller_rety.clone();
574        let callee_rety = match sig.output.clone() {
575            ReturnType::Default => Type::Verbatim(quote! {()}),
576            ReturnType::Type(_, t) => t.as_ref().clone(),
577        };
578        let ty: Type = Type::Verbatim(quote! {#ident<#callee_rety,#caller_rety>});
579        sig.output = ReturnType::Type(syn::parse_quote! {->}, Box::new(ty));
580
581        let ok = quote!(Ok);
582        match block.stmts.last_mut() {
583            None => {}
584            Some(s) => {
585                // println!("last stmt: {}", s.into_token_stream().to_string());
586                match s {
587                    Stmt::Expr(_) => {
588                        let mut helper = MakeLastReturnBlkVisitor {};
589                        helper.visit_stmt_mut(s);
590                        let re = quote!(result);
591                        let ret_stmt_expr: Expr = syn::parse_quote! {#ident::#ok(#re)};
592                        block.stmts.push(Stmt::Expr(ret_stmt_expr))
593                    }
594                    _ => {
595                        let ret_stmt_expr: Expr = syn::parse_quote! {#ident::#ok(())};
596                        block.stmts.push(Stmt::Expr(ret_stmt_expr))
597                    }
598                }
599            }
600        }
601    }
602}
603
604struct MakeCallerReturnHelper<'a> {
605    callee_fn_name: &'a str,
606}
607impl VisitMut for MakeCallerReturnHelper<'_> {
608    fn visit_expr_mut(&mut self, i: &mut Expr) {
609        debug!("expr: {:?}", i);
610        syn::visit_mut::visit_expr_mut(self, i);
611    }
612
613    fn visit_expr_return_mut(&mut self, i: &mut ExprReturn) {
614        let ident_str = format!("{}{}", ENUM_NAME, make_pascal_case(self.callee_fn_name));
615        let ident = Ident::new(ident_str.as_str(), Span::call_site());
616        let return_t = quote! {Return};
617        match i.expr.clone() {
618            None => {
619                let rety: Expr = syn::parse_quote! {#ident::#return_t(())};
620                i.expr = Some(Box::new(rety))
621            }
622            Some(e) => {
623                let e = e.as_ref().clone();
624                let rety: Expr = syn::parse_quote! {#ident::#return_t(#e)};
625                i.expr = Some(Box::new(rety));
626            }
627        }
628    }
629}
630
631struct MakeCallerReturn<'a> {
632    callee_fn_name: &'a str,
633}
634
635impl VisitMut for MakeCallerReturn<'_> {
636    fn visit_impl_item_method_mut(&mut self, i: &mut ImplItemMethod) {
637        let id = i.sig.ident.to_string();
638        match id.contains(self.callee_fn_name) {
639            true => {
640                debug!("found callee: {:?}", i);
641                let mut helper = MakeCallerReturnHelper {
642                    callee_fn_name: self.callee_fn_name,
643                };
644                helper.visit_impl_item_method_mut(i)
645            }
646            false => {}
647        }
648        syn::visit_mut::visit_impl_item_method_mut(self, i);
649    }
650    fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
651        let id = i.sig.ident.to_string();
652        match id.contains(self.callee_fn_name) {
653            false => (),
654            true => {
655                debug!("found callee: {:?}", i);
656                let mut helper = MakeCallerReturnHelper {
657                    callee_fn_name: self.callee_fn_name,
658                };
659                helper.visit_item_fn_mut(i)
660            }
661        }
662    }
663    fn visit_trait_item_method_mut(&mut self, i: &mut TraitItemMethod) {
664        let id = i.sig.ident.to_string();
665        match id.contains(self.callee_fn_name) {
666            true => {
667                debug!("found callee: {:?}", i);
668                let mut helper = MakeCallerReturnHelper {
669                    callee_fn_name: self.callee_fn_name,
670                };
671                helper.visit_trait_item_method_mut(i);
672            }
673            false => {}
674        }
675        syn::visit_mut::visit_trait_item_method_mut(self, i);
676    }
677}
678
679struct MatchCallSiteHelper<'a> {
680    callee_fn_name: &'a str,
681    has_return: bool,
682    has_continue: bool,
683    has_break: bool,
684}
685
686impl VisitMut for MatchCallSiteHelper<'_> {
687    fn visit_expr_mut(&mut self, i: &mut Expr) {
688        // println!("visit expr: {}", i.into_token_stream().to_string());
689        match i {
690            Expr::Call(c) => {
691                let id = c.func.clone().as_ref().into_token_stream().to_string();
692                match id.contains(self.callee_fn_name) {
693                    true => {
694                        let e = i.clone().into_token_stream().to_string();
695                        let enum_name_fn = make_pascal_case(self.callee_fn_name);
696                        let match_str = format!(
697                            "match {} {{\n{} {} {} {}\n}}",
698                            e,
699                            format!("{}{}::Ok(x) => x,\n", ENUM_NAME, enum_name_fn),
700                            if self.has_return {
701                                format!("{}{}::Return(x) => return x,\n", ENUM_NAME, enum_name_fn)
702                            } else {
703                                "".to_string()
704                            },
705                            if self.has_break {
706                                format!("{}{}::Break => break,\n", ENUM_NAME, enum_name_fn)
707                            } else {
708                                "".to_string()
709                            },
710                            if self.has_continue {
711                                format!("{}{}::Continue => continue,", ENUM_NAME, enum_name_fn)
712                            } else {
713                                "".to_string()
714                            },
715                        );
716                        let match_expr: ExprMatch = syn::parse_str(match_str.as_str()).unwrap();
717                        *i = Expr::Match(match_expr)
718                    }
719                    false => syn::visit_mut::visit_expr_mut(self, i),
720                }
721            }
722            // NEED TO FIX TO INCLUDE CHECK FOR OTHER CALL SITE SUCH AS self. and Self::
723            _ => syn::visit_mut::visit_expr_mut(self, i),
724        }
725    }
726}
727
728struct MatchCallSite<'a> {
729    caller_fn_name: &'a str,
730    callee_finder: &'a mut FindCallee<'a>,
731    callee_fn_name: &'a str,
732    has_return: bool,
733    has_continue: bool,
734    has_break: bool,
735    enum_str: String,
736    added_enum: bool,
737}
738
739impl VisitMut for MatchCallSite<'_> {
740    fn visit_impl_item_method_mut(&mut self, i: &mut ImplItemMethod) {
741        if self.callee_finder.found {
742            return;
743        }
744
745        let id = i.sig.ident.to_string();
746        match id.contains(self.caller_fn_name) {
747            false => (),
748            true => {
749                self.callee_finder.visit_impl_item_method_mut(i);
750                if !self.callee_finder.found {
751                    return;
752                }
753                self.match_callsite(&mut i.block);
754            }
755        }
756        syn::visit_mut::visit_impl_item_method_mut(self, i);
757    }
758
759    fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
760        if self.callee_finder.found {
761            return;
762        }
763
764        let id = i.sig.ident.to_string();
765        match id.contains(self.caller_fn_name) {
766            true => {
767                self.callee_finder.visit_item_fn_mut(i);
768                if !self.callee_finder.found {
769                    return;
770                }
771                self.match_callsite(&mut i.block);
772            }
773            false => {}
774        }
775        syn::visit_mut::visit_item_fn_mut(self, i);
776    }
777
778
779    fn visit_item_mod_mut(&mut self, i: &mut ItemMod) {
780        if i.clone()
781            .into_token_stream()
782            .to_string()
783            .contains(self.callee_fn_name)
784        {
785            match i.content.as_mut() {
786                None => {}
787                Some((_, items)) => {
788                    items.push(syn::parse_str(self.enum_str.as_str()).unwrap());
789                    self.added_enum = true;
790                }
791            }
792        }
793        syn::visit_mut::visit_item_mod_mut(self, i);
794    }
795
796    fn visit_trait_item_method_mut(&mut self, i: &mut TraitItemMethod) {
797        if self.callee_finder.found {
798            return;
799        }
800
801        let id = i.sig.ident.to_string();
802        //println!("caller name: {}, at: {}", self.caller_fn_name, &id);
803        match id.contains(self.caller_fn_name) {
804            false => (),
805            true => {
806                self.callee_finder.visit_trait_item_method_mut(i);
807                if !self.callee_finder.found {
808                    return;
809                }
810                let _ = i
811                    .clone()
812                    .default
813                    .as_mut()
814                    .and_then(|block| Some(self.match_callsite(block)));
815            }
816        }
817        syn::visit_mut::visit_trait_item_method_mut(self, i);
818    }
819}
820
821impl MatchCallSite<'_> {
822    fn match_callsite(&mut self, block: &mut Block) {
823        let mut helper = MatchCallSiteHelper {
824            callee_fn_name: self.callee_fn_name,
825            has_return: self.has_return,
826            has_continue: self.has_continue,
827            has_break: self.has_break,
828        };
829        helper.visit_block_mut(block);
830    }
831}
832
/// Outcome of [`inner_make_controls`].
///
/// `Clone`, `PartialEq` and `Eq` are derived so results can be copied and compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NonLocalControlFlowResult {
    /// `false` when the caller or callee was not found, or a rewrite flagged a problem.
    pub success: bool,
    /// Whether the callee contained a `return` (including one produced by desugaring `?`).
    #[allow(dead_code)]
    pub has_return: bool,
    /// Whether the callee contained a `continue` outside its own loops.
    #[allow(dead_code)]
    pub has_continue: bool,
    /// Whether the callee contained a `break` outside its own loops.
    #[allow(dead_code)]
    pub has_break: bool,
    /// Number of parameters in the callee's signature.
    #[allow(dead_code)]
    pub num_inputs: usize,
    /// The resulting source text; on the failure paths this is the unmodified input.
    #[allow(dead_code)]
    pub output_code: String,
}
847
848pub fn inner_make_controls(
849    input_code: String,
850    callee_fn_name: &str,
851    caller_fn_name: &str,
852) -> NonLocalControlFlowResult {
853    let mut success: bool = true;
854    debug!("debugging controller...");
855    let file_content: String = input_code;
856
857    let mut file: File = parse_str::<File>(file_content.clone().as_str())
858        .map_err(|e| {
859            let s = format!("THERE IS AN ERROR HERE NOT PARSED: {:?}", e);
860            println!("errored: {}", &s);
861            s
862        })
863        .unwrap();
864
865    let mut caller_rety: ReturnType = ReturnType::Default;
866    let mut caller_visitor: CallerVisitor<'_> = CallerVisitor {
867        found: false,
868        caller_fn_name,
869        callee_finder: &mut FindCallee {
870            found: false,
871            callee_fn_name,
872        },
873        callee_fn_name,
874        callee_in_loop: false,
875        caller_rety: &mut caller_rety,
876    };
877    caller_visitor.visit_file_mut(&mut file);
878    if !caller_visitor.found {
879        debug!("did not find caller");
880        return NonLocalControlFlowResult {
881            success: false,
882            has_return: false,
883            has_continue: false,
884            has_break: false,
885            num_inputs: 0,
886            output_code: file_content,
887        };
888    }
889
890    let mut callee_visitor = CalleeCheckNCF {
891        found: false,
892        callee_fn_name,
893        caller_rety: caller_visitor.caller_rety.clone(),
894        within_caller_loop: caller_visitor.callee_in_loop,
895        has_break: false,
896        has_continue: false,
897        has_return: false,
898        num_inputs: 0,
899    };
900    callee_visitor.visit_file_mut(&mut file);
901
902    if !callee_visitor.found {
903        debug!("did not find callee");
904        return NonLocalControlFlowResult {
905            success: false,
906            has_return: false,
907            has_continue: false,
908            has_break: false,
909            num_inputs: 0,
910            output_code: file_content,
911        };
912    }
913
914    debug!("callee_visitor: {:?}", callee_visitor);
915    if callee_visitor.has_return || callee_visitor.has_continue || callee_visitor.has_break {
916        let caller_rety = match caller_visitor.caller_rety {
917            ReturnType::Default => Type::Verbatim(quote! {()}),
918            ReturnType::Type(_, t) => t.as_ref().clone(),
919        };
920        let mut already_did_return = false;
921
922        if callee_visitor.has_return {
923            debug!("has return!");
924            let mut make_ret = MakeReturn {
925                callee_fn_name,
926                caller_rety: &caller_rety,
927            };
928            make_ret.visit_file_mut(&mut file);
929
930            let mut make_caller_ret = MakeCallerReturn { callee_fn_name };
931            make_caller_ret.visit_file_mut(&mut file);
932            already_did_return = true;
933        }
934
935        if callee_visitor.has_break || callee_visitor.has_continue {
936            debug!("has break {} or cont {}",callee_visitor.has_break, callee_visitor.has_continue);
937            let mut make_brk_and_cont = MakeBrkAndCont {
938                callee_fn_name,
939                success,
940                already_did_return,
941            };
942            make_brk_and_cont.visit_file_mut(&mut file);
943            success = make_brk_and_cont.success
944        }
945
946        let ident_str = format!("{}{}", ENUM_NAME, make_pascal_case(callee_fn_name));
947        let enum_str = format!(
948            "enum {}<A{}> \n{{Ok(A),\n{}{}{}}}",
949            ident_str,
950            if callee_visitor.has_return { ", B" } else { "" },
951            if callee_visitor.has_return {
952                "Return(B),\n"
953            } else {
954                ""
955            },
956            if callee_visitor.has_break {
957                "Break,\n"
958            } else {
959                ""
960            },
961            if callee_visitor.has_continue {
962                "Continue,\n"
963            } else {
964                ""
965            },
966        );
967
968        let mut caller_matcher = MatchCallSite {
969            caller_fn_name,
970            callee_finder: &mut FindCallee {
971                found: false,
972                callee_fn_name,
973            },
974            callee_fn_name,
975            has_return: callee_visitor.has_return,
976            has_continue: callee_visitor.has_continue,
977            has_break: callee_visitor.has_break,
978            enum_str: enum_str.clone(),
979            added_enum: false,
980        };
981        caller_matcher.visit_file_mut(&mut file);
982
983        if !caller_matcher.added_enum {
984            file.items.push( parse_str( enum_str.as_str() ).unwrap());
985        }
986    }
987    let src: String = file.into_token_stream().to_string();
988    NonLocalControlFlowResult {
989        success,
990        has_return: callee_visitor.has_return,
991        has_continue: callee_visitor.has_continue,
992        has_break: callee_visitor.has_break,
993        num_inputs: callee_visitor.num_inputs,
994        output_code: format_source( &src ),
995    }
996}
997
998pub fn make_controls(
999    input: ControllerInput,
1000) -> Result<String, ControllerError> {
1001    let res: NonLocalControlFlowResult = inner_make_controls(
1002        input.input_code,
1003        &input.callee_fn_name,
1004        &input.caller_fn_name
1005    );
1006    debug!("result: {:?}", res);
1007    if res.success {
1008        Ok(res.output_code)
1009    } else {
1010        Err(ControllerError::MakeControlsFailed)
1011    }
1012}