1use std::{
10 collections::HashMap,
11 iter::Peekable,
12 ops::{Deref, DerefMut},
13};
14
15use syn::{
16 Arm, Attribute, Block, Error, Expr, ExprBlock, ExprBreak, ExprForLoop, ExprIf, ExprLoop,
17 ExprMatch, ExprReturn, ExprWhile, ExprYield, FnArg, GenericArgument, GenericParam, Generics,
18 Ident, Lifetime, LifetimeParam, Local, Macro, Pat, PatType, Stmt, Token, Type, Visibility,
19 parenthesized,
20 parse::{Parse, ParseStream, discouraged::Speculative},
21 parse_macro_input, parse_quote, parse_quote_spanned, parse2,
22 punctuated::Punctuated,
23 spanned::Spanned,
24 token::{Comma, Paren, RArrow, Semi, Unsafe},
25 visit::{self, Visit},
26 visit_mut::{self, VisitMut},
27};
28
29use proc_macro2::{Span, TokenStream};
30use quote::{ToTokens, format_ident, quote, quote_spanned};
31
/// Index in the context's `states` array that records completion:
/// `0` while the Karutin is still running, `1` once it has returned.
const COMPLETED_STATE_ID: usize = 0;
/// Name of the lifetime put on generated context types and on otherwise-elided
/// references in the parameter, yield and return types.
const LIFETIME_STR: &str = "'__karutin_lifetime__";
/// Label of the outer `loop` wrapping every generated body: lowered `yield`/`return`
/// break out of it and state arms `continue` it.
const STATE_LOOP_LABEL_STR: &str = "'__karutin_state_loop__";
/// Placeholder name given to a `let` binding once its value lives on the generated stack.
const LET_BINDING_IDENT_STR: &str = "__karutin_let_binding__";
36
/// Name of the generated stack struct for function `$i`: `__<i>_karutin_stack__`.
macro_rules! format_stack_ident {
    ($i:expr) => {
        format_ident!("__{}_karutin_stack__", $i)
    };
}

/// Name of the generated context struct for function `$i`: `__<i>_karutin_ctx__`.
macro_rules! format_context_ident {
    ($i:expr) => {
        format_ident!("__{}_karutin_ctx__", $i)
    };
}

/// Name of the `$i`-th generic type parameter of the stack struct: `T<i>`.
macro_rules! format_generic_ident {
    ($i:expr) => {
        format_ident!("T{}", $i)
    };
}

/// Name of the `$i`-th field of the stack struct: `f<i>`.
macro_rules! format_field_ident {
    ($i:expr) => {
        format_ident!("f{}", $i)
    };
}
60
61fn is_yield_stmt(stmt: &Stmt) -> bool {
62 match stmt {
63 Stmt::Expr(Expr::Yield(_), _) => true,
64 _ => false,
65 }
66}
67
68fn is_loop_stmt(stmt: &Stmt) -> bool {
69 match stmt {
70 Stmt::Expr(Expr::Loop(_), _) => true,
71 _ => false,
72 }
73}
74
75fn is_break_stmt(stmt: &Stmt) -> bool {
76 match stmt {
77 Stmt::Expr(Expr::Break(_), _) => true,
78 _ => false,
79 }
80}
81
82#[derive(Default)]
83struct PotentialYieldCheck(bool);
84
85impl<'a> Visit<'a> for PotentialYieldCheck {
86 fn visit_expr_yield(&mut self, node: &'a syn::ExprYield) {
87 self.0 = true;
88 visit::visit_expr_yield(self, node);
89 }
90}
91
92fn is_potential_yield_stmt(stmt: &Stmt) -> bool {
93 let mut check = PotentialYieldCheck::default();
94 check.visit_stmt(stmt);
95 check.0
96}
97
98fn convert_yield(expr_yield: &mut ExprYield) -> ExprBreak {
99 let expr = expr_yield
100 .expr
101 .take()
102 .unwrap_or_else(|| Box::new(parse_quote!(())));
103
104 let state_loop_label = Lifetime::new(STATE_LOOP_LABEL_STR, Span::call_site());
105 let span = expr_yield.yield_token.span;
106
107 syn::parse_quote_spanned! {span=>
108 break #state_loop_label ({ ::karutin::KarutinState::Yielded( #expr ) })
109 }
110}
111
112fn convert_return(expr_return: &mut ExprReturn) -> ExprBlock {
113 let expr = expr_return
114 .expr
115 .take()
116 .unwrap_or_else(|| Box::new(parse_quote!(())));
117
118 let state_loop_label = Lifetime::new(STATE_LOOP_LABEL_STR, Span::call_site());
119 let span = expr_return.return_token.span;
120
121 syn::parse_quote_spanned! {span=>{
122 self.states[#COMPLETED_STATE_ID] = 1;
123 break #state_loop_label ({ ::karutin::KarutinState::Returned( #expr ) })
124 }}
125}
126
127fn continue_state_stmt(state_id: usize) -> Stmt {
128 syn::parse_quote! {
129 self.states[#state_id] += 1;
130 }
131}
132
133fn create_state_arm(state: usize, block: Block) -> Arm {
134 create_arm(syn::parse_quote! { #state }, block)
135}
136
137fn create_arm(pat: Pat, block: Block) -> Arm {
138 syn::parse_quote! {
139 #pat => #block
140 }
141}
142
143fn chunk_by_statefuls(stmts: Vec<Stmt>) -> Vec<Vec<Stmt>> {
144 stmts
145 .chunk_by(|s1: &Stmt, s2: &Stmt| {
146 !is_yield_stmt(s1)
147 && !is_potential_yield_stmt(s2)
148 && !is_loop_stmt(s1)
149 && !is_loop_stmt(s2)
150 })
151 .map(|c| c.to_owned())
152 .collect::<Vec<Vec<Stmt>>>()
153}
154
155#[rustfmt::skip]
156fn attach_state_match_arms(
157 state_id: usize,
158 match_expr: &mut ExprMatch,
159 blocks: Vec<Block>,
160) {
161 let block_count = blocks.len();
162
163 for (i, mut block) in blocks.into_iter().enumerate() {
164 let stmt_count = block.stmts.len();
165 block.stmts.push(continue_state_stmt(state_id));
166
167 if is_break_stmt(&block.stmts[stmt_count - 1]) {
168 block.stmts.swap(stmt_count, stmt_count - 1);
169 } else if i != block_count - 1 {
170 let state_loop_label = Lifetime::new(
171 STATE_LOOP_LABEL_STR,
172 Span::call_site(),
173 );
174
175 block.stmts.push(parse_quote!{
176 continue #state_loop_label;
177 });
178 }
179
180 let arm = create_state_arm(i, block);
181 match_expr.arms.push(arm);
182 }
183
184 let fall_arm = create_arm(
185 syn::parse_quote! { _ },
186 syn::parse_quote! { { } }
187 );
188
189 match_expr.arms.push(fall_arm);
190}
191
/// Compile-error text reported at every macro invocation found in a Karutin function body.
const MACRO_USAGE_ERR: &str = "\
    Macros cannot be used in Karutin functions!\n\n\
    Procedural macros cannot expand the macros they receive,\n\
    so while lowering the code, Karutin cannot see what happens inside them.\n\
    Because of this, the state machine and stack management would not work.\n\
    This may be solved in the future.
";
199
200#[derive(Default)]
201struct MacroSpans {
202 inner: Vec<Span>,
203}
204
205impl MacroSpans {
206 pub fn into_inner(self) -> Vec<Span> {
207 self.inner
208 }
209}
210
211impl<'a> Visit<'a> for MacroSpans {
212 fn visit_macro(&mut self, macro_: &'a Macro) {
213 self.inner.push(macro_.span());
214 }
215}
216
217fn check_blocks_macro_usage(karutin_fn: &KarutinFn) -> Option<Error> {
218 let mut macro_usage = MacroSpans::default();
219 macro_usage.visit_block(&karutin_fn.block);
220
221 macro_usage
222 .into_inner()
223 .into_iter()
224 .map(|span| Error::new(span, MACRO_USAGE_ERR))
225 .reduce(|mut acc, error| {
226 acc.combine(error);
227 acc
228 })
229}
230
/// Compile-error text reported for a binding declared without `mut`.
const LET_BINDING_MUTABILITY_ERR: &str = "\
    Locals cannot be immutable!\n\n\
    Because of the way the stack is managed, locals are always movable/mutable.\n\
    Preventing this would require Karutin to track moves and mutability, which is hard and even impossible in some cases.\n\
    So for now, locals must be explicitly declared `mut`, so that you know what they really are.\n\
    This may be solved in the future.
";

/// Compile-error text reported for patterns Karutin cannot lower.
const COMPLEX_PATTERN_ERR: &str = "\
    Karutin functions cannot have complex patterns!\n\n\
    Code lowering and stack management for complex patterns are hard to implement.\n\
    So for now, only simple patterns are available.\n\
    This may be solved in the future.
";
245
246#[derive(PartialEq, Eq, Hash)]
247enum RestrictionError {
248 Mutability,
249 ComplexPattern,
250}
251
252impl RestrictionError {
253 pub const fn get_message(&self) -> &str {
254 match self {
255 RestrictionError::Mutability => LET_BINDING_MUTABILITY_ERR,
256 RestrictionError::ComplexPattern => COMPLEX_PATTERN_ERR,
257 }
258 }
259}
260
261#[derive(Default)]
262struct RestrictionErrors(Vec<(RestrictionError, Span)>);
263
264impl RestrictionErrors {
265 pub fn into_inner(self) -> Vec<(RestrictionError, Span)> {
266 self.0
267 }
268
269 fn check_general_pattern(&mut self, pat: &Pat) {
270 use RestrictionError::{ComplexPattern, Mutability};
271
272 match pat {
273 Pat::Ident(pat_ident) => {
274 if pat_ident.mutability.is_none() {
275 self.push((Mutability, pat.span()));
276 };
277
278 if let Some((_, subpat)) = &pat_ident.subpat {
279 self.push((ComplexPattern, subpat.span()));
280 }
281 },
282 Pat::Wild(_) => {
283 },
285 _ => self.push((ComplexPattern, pat.span())),
286 }
287 }
288}
289
290impl Deref for RestrictionErrors {
291 type Target = Vec<(RestrictionError, Span)>;
292
293 fn deref(&self) -> &Self::Target {
294 &self.0
295 }
296}
297
298impl DerefMut for RestrictionErrors {
299 fn deref_mut(&mut self) -> &mut Self::Target {
300 &mut self.0
301 }
302}
303
304impl<'a> Visit<'a> for RestrictionErrors {
305 fn visit_local(&mut self, node: &'a syn::Local) {
306 use RestrictionError::ComplexPattern;
307
308 if let Some(init) = &node.init
309 && let Some(diverge) = &init.diverge
310 {
311 self.push((ComplexPattern, diverge.1.span()));
312 }
313
314 self.check_general_pattern(&node.pat);
315 visit::visit_local(self, node);
316 }
317
318 fn visit_expr_for_loop(&mut self, node: &'a syn::ExprForLoop) {
319 self.check_general_pattern(node.pat.as_ref());
320 visit::visit_expr_for_loop(self, node);
321 }
322}
323
324fn check_restriction_errors(karutin_fn: &KarutinFn) -> Option<Error> {
325 let mut restriction_errors = RestrictionErrors::default();
326 restriction_errors.visit_block(&karutin_fn.block);
327
328 let errors = restriction_errors
329 .into_inner()
330 .into_iter()
331 .map(|(err_type, span)| Error::new(span, err_type.get_message()));
332
333 errors.reduce(|mut acc, error| {
334 acc.combine(error);
335 acc
336 })
337}
338
339fn create_stack_generics(count: usize) -> TokenStream {
340 let mut stream = TokenStream::new();
341
342 for i in 0..count {
343 let ty: Ident = format_generic_ident!(i);
344 stream.extend(quote! { #ty, });
345 }
346
347 stream
348}
349
350fn create_empty_stack_generics(count: usize) -> TokenStream {
351 let mut stream = TokenStream::new();
352
353 for _ in 0..count {
354 stream.extend(quote! { _, });
355 }
356
357 stream
358}
359
360fn create_stack_field_idents(count: usize) -> impl Iterator<Item = Ident> {
361 (0..count).map(|i| format_field_ident!(i)).into_iter()
362}
363
364fn create_stack_fields(count: usize) -> TokenStream {
365 let mut stream = TokenStream::new();
366
367 for i in 0..count {
368 let ident: Ident = format_field_ident!(i);
369 let ty: Ident = format_generic_ident!(i);
370
371 stream.extend(quote! {
372 #ident: #ty,
373 });
374 }
375
376 stream
377}
378
379#[derive(Default)]
380struct Transpiler;
381
382impl Transpiler {
383 const SKIP_ATTR_STR: &str = "__skip_transpile__";
384 const YIELD_FROM_ATTR_STR: &str = "__yield_from__";
385
386 fn create_attr(ident: &str) -> Attribute {
387 let _ident = Ident::new(ident, Span::mixed_site());
388 parse_quote! { #[#_ident] }
389 }
390
391 pub fn create_skip_attr() -> Attribute {
392 Self::create_attr(Self::SKIP_ATTR_STR)
393 }
394
395 pub fn create_yield_from_attr() -> Attribute {
396 Self::create_attr(Self::YIELD_FROM_ATTR_STR)
397 }
398
399 fn get_attr_index(attrs: &Vec<Attribute>, ident: &str) -> Option<usize> {
400 attrs.iter().enumerate().find_map(|(i, attr)| {
401 attr.path().is_ident(ident);
402 Some(i)
403 })
404 }
405
406 fn remove_attr(attrs: &mut Vec<Attribute>, ident: &str) -> bool {
407 if let Some(i) = Self::get_attr_index(attrs, ident) {
408 attrs.remove(i);
409 true
410 } else {
411 false
412 }
413 }
414
415 fn remove_skip_attr(attrs: &mut Vec<Attribute>) -> bool {
416 Self::remove_attr(attrs, Self::SKIP_ATTR_STR)
417 }
418
419 fn remove_yield_from_attr(attrs: &mut Vec<Attribute>) -> bool {
420 Self::remove_attr(attrs, Self::YIELD_FROM_ATTR_STR)
421 }
422
423 fn transpile_for_loop(node: &mut ExprForLoop) -> ExprBlock {
442 let pat = &node.pat;
443 let expr = &node.expr;
444 let body = &node.body;
445 let label = node.label.as_ref();
446
447 let skip_attr = Self::create_skip_attr();
448 let mut for_loop_: ExprForLoop = parse_quote! {
449 #skip_attr
450 for _ in [(); 0] {}
451 };
452
453 let mut loop_: ExprLoop = parse_quote! {
454 loop {
455 let #pat = match iter.next() {
456 Some(v) => {v},
457 None => break,
458 };
459 #body
460 }
461 };
462
463 if let Some(label) = label {
464 loop_.label = Some(label.clone());
465 }
466
467 for_loop_.for_token.span = node.for_token.span;
468 for_loop_.in_token.span = node.in_token.span;
469
470 parse_quote! {{
471 #for_loop_
472 let mut iter = ::std::iter::IntoIterator::into_iter(#expr);
473 #loop_
474 }}
475 }
476
477 fn transpile_while_loop(node: &mut ExprWhile) -> ExprBlock {
494 let expr = &node.cond;
495 let body = &node.body;
496 let label = node.label.as_ref();
497
498 let skip_attr = Self::create_skip_attr();
499 let mut while_loop_: ExprWhile = parse_quote! {
500 #skip_attr
501 while false {}
502 };
503
504 let mut loop_: ExprLoop = parse_quote! {
505 loop {
506 if #expr #body
507 else {
508 break;
509 }
510 }
511 };
512
513 if let Some(label) = label {
514 loop_.label = Some(label.clone());
515 }
516
517 while_loop_.while_token.span = node.while_token.span;
518
519 parse_quote! {{
520 #while_loop_
521 #loop_
522 }}
523 }
524
525 fn transpile_yield_from(node: &mut ExprYield) -> ExprBlock {
540 let expr = &node.expr;
541
542 parse_quote! {{
543 for val in ::karutin::into_value_iter!(#expr) {
544 yield val
545 }
546 }}
547 }
548}
549
550impl VisitMut for Transpiler {
551 fn visit_expr_mut(&mut self, node: &mut syn::Expr) {
552 match node {
553 Expr::ForLoop(expr_for_loop) => {
554 if Self::remove_skip_attr(&mut expr_for_loop.attrs) {
555 return;
556 };
557
558 *node = Expr::Block(Self::transpile_for_loop(expr_for_loop));
559 visit_mut::visit_expr_mut(self, node);
560 },
561 Expr::While(expr_while) => {
562 *node = Expr::Block(Self::transpile_while_loop(expr_while));
563 visit_mut::visit_expr_mut(self, node);
564 },
565 Expr::Yield(expr_yield) => {
566 if Self::remove_yield_from_attr(&mut expr_yield.attrs) {
567 *node = Expr::Block(Self::transpile_yield_from(expr_yield));
568 visit_mut::visit_expr_mut(self, node);
569 } else {
570 visit_mut::visit_expr_yield_mut(self, expr_yield);
571 }
572 },
573 _ => {},
574 }
575 }
576}
577
578fn transpile(node: &mut Block) {
579 let mut transpiler = Transpiler::default();
580 transpiler.visit_block_mut(node);
581}
582
583#[derive(Default)]
584struct StateMachine {
585 pub state_count: usize,
586}
587
588impl StateMachine {
589 fn create_state_match_expr(&mut self, blocks: Vec<Block>) -> ExprMatch {
590 let state_id = self.state_count;
591 self.state_count += 1;
592
593 let mut match_expr: ExprMatch = syn::parse_quote! {
594 match self.states[#state_id] {}
595 };
596
597 attach_state_match_arms(state_id, &mut match_expr, blocks);
598
599 match_expr
600 }
601
602 fn visit_block_stmts(&mut self, block: &mut Block) {
603 for it in &mut block.stmts {
604 self.visit_stmt_mut(it);
605 }
606 }
607}
608
609impl VisitMut for StateMachine {
610 fn visit_expr_loop_mut(&mut self, node: &mut syn::ExprLoop) {
611 let start = self.state_count;
612 visit_mut::visit_expr_loop_mut(self, node);
613 let end = self.state_count;
614
615 let if_expr: ExprIf = syn::parse_quote! {
616 if let Some(states) = self.states.get_mut(#start..#end) {
617 states.fill(0);
618 }
619 };
620
621 let expr = Expr::If(if_expr);
622 let stmt = Stmt::Expr(expr, None);
623
624 node.body.stmts.push(stmt);
625 }
626
627 fn visit_expr_mut(&mut self, node: &mut syn::Expr) {
628 match node {
629 Expr::Yield(expr_yield) => {
630 if let Some(expr) = &mut expr_yield.expr {
631 self.visit_expr_mut(expr);
632 }
633
634 *node = Expr::Break(convert_yield(expr_yield));
635 },
636 Expr::Return(expr_return) => {
637 if let Some(expr) = &mut expr_return.expr {
638 self.visit_expr_mut(expr);
639 }
640
641 *node = Expr::Block(convert_return(expr_return));
642 },
643 _ => {},
644 }
645
646 visit_mut::visit_expr_mut(self, node);
647 }
648
649 fn visit_block_mut(&mut self, node: &mut syn::Block) {
650 let stmts = std::mem::replace(&mut node.stmts, vec![]);
651 let mut chunks = chunk_by_statefuls(stmts);
652
653 if chunks.len() == 1 {
654 if let Some(stmt) = chunks[0].last()
655 && !is_yield_stmt(stmt)
656 {
657 std::mem::swap(&mut node.stmts, &mut chunks[0]);
658
659 self.visit_block_stmts(node);
660 return;
661 }
662 }
663
664 let mut blocks = chunks
665 .into_iter()
666 .map(|chunk| Block {
667 brace_token: Default::default(),
668 stmts: chunk,
669 })
670 .collect::<Vec<Block>>();
671
672 let mut last_block = blocks.pop();
673
674 if let Some(last_block_ref) = &last_block
675 && last_block_ref.stmts.len() == 1
676 && is_yield_stmt(&last_block_ref.stmts[0])
677 {
678 blocks.push(last_block.take().unwrap());
679 }
680
681 for mut block in blocks.iter_mut() {
682 self.visit_block_stmts(&mut block);
683 }
684
685 let match_expr = self.create_state_match_expr(blocks);
686
687 let expr = Expr::Match(match_expr);
688 let stmt = Stmt::Expr(expr, None);
689
690 node.stmts.push(stmt);
691
692 if let Some(mut last_block) = last_block {
693 self.visit_block_stmts(&mut last_block);
694
695 let block_expr = ExprBlock {
696 attrs: vec![],
697 label: None,
698 block: last_block,
699 };
700
701 let expr = Expr::Block(block_expr);
702 let stmt = Stmt::Expr(expr, None);
703
704 node.stmts.push(stmt);
705 }
706 }
707}
708
709fn sift_states(node: &mut Block) -> usize {
710 let mut state_machine = StateMachine::default();
711
712 state_machine.state_count += 1;
713 state_machine.visit_block_mut(node);
714
715 state_machine.state_count
716}
717
/// One lexical scope: maps a local's name to the index of its stack slot.
#[derive(Default)]
struct StackScope(HashMap<String, usize>);

/// Visitor that moves `let` bindings of a Karutin body onto the generated stack struct
/// and rewrites uses of those locals into `stack.f<N>` field accesses.
#[derive(Default)]
struct StackBuilder {
    /// Currently open block scopes, innermost last.
    pub scopes: Vec<StackScope>,
    /// Number of stack slots allocated so far (also the next slot index).
    pub local_count: usize,
}
726
727impl StackBuilder {
728 fn lookup_local(&self, ident: &Ident) -> Option<usize> {
729 let ident_str = ident.to_string();
730 let mut result = Option::<usize>::None;
731
732 for scope in self.scopes.iter().rev() {
733 if result.is_some() {
734 break;
735 }
736
737 let ret = scope.0.get(&ident_str);
738
739 if let Some(id) = ret {
740 result = Some(*id);
741 }
742 }
743
744 result
745 }
746
747 fn insert_local(&mut self, ident: &Ident) -> usize {
748 let ident_str = ident.to_string();
749 let result = self.local_count;
750
751 let last = self.scopes.last_mut().unwrap();
752
753 last.0.insert(ident_str, result);
754
755 self.local_count += 1;
756 result
757 }
758
759 fn convert_expr(&self, expr: &mut Expr) -> bool {
760 if let Expr::Path(expr_path) = expr
761 && expr_path.path.segments.len() == 1
762 {
763 let ident = &expr_path.path.segments[0].ident;
764
765 if let Some(id) = self.lookup_local(ident) {
766 let mut field_ident = format_field_ident!(id);
767 field_ident.set_span(ident.span());
768
769 let new_expr: Expr = parse_quote_spanned! {ident.span()=>
770 stack.#field_ident
771 };
772
773 *expr = new_expr;
774 return true;
775 }
776 }
777
778 false
779 }
780}
781
782impl VisitMut for StackBuilder {
783 fn visit_expr_mut(&mut self, node: &mut Expr) {
784 if !self.convert_expr(node) {
785 visit_mut::visit_expr_mut(self, node);
786 }
787 }
788
789 fn visit_local_mut(&mut self, node: &mut Local) {
790 if let Pat::Ident(pat_ident) = &mut node.pat {
791 let id = self.insert_local(&pat_ident.ident);
792 let ident_span = pat_ident.ident.span();
793
794 pat_ident.ident = Ident::new(LET_BINDING_IDENT_STR, ident_span);
795
796 if let Some(init) = &mut node.init {
797 let mut field_ident = format_field_ident!(id);
798 field_ident.set_span(ident_span);
799
800 let expr = &init.expr;
801 let block_expr: ExprBlock = parse_quote_spanned! {ident_span=>{
802 stack.#field_ident = #expr
803 }};
804
805 init.expr = Box::new(block_expr.into());
806 self.visit_expr_mut(&mut init.expr);
807 }
808 } else {
809 visit_mut::visit_local_mut(self, node);
810 }
811 }
812
813 fn visit_block_mut(&mut self, node: &mut Block) {
814 self.scopes.push(Default::default());
815 visit_mut::visit_block_mut(self, node);
816 self.scopes.pop();
817 }
818}
819
820fn build_stack(node: &mut Block) -> usize {
821 let mut builder = StackBuilder::default();
822
823 builder.visit_block_mut(node);
824
825 builder.local_count
826}
827
828struct KarutinReturnType {
829 pub yield_type: Box<Type>,
830 pub return_type: Box<Type>,
831}
832
833impl Parse for KarutinReturnType {
834 fn parse(input: ParseStream) -> syn::Result<Self> {
835 Ok(Self {
836 yield_type: {
837 input.parse::<Token![->]>()?;
838 input.parse()?
839 },
840 return_type: {
841 input.parse::<Token![..]>()?;
842 input.parse()?
843 },
844 })
845 }
846}
847
848impl ToTokens for KarutinReturnType {
849 fn to_tokens(&self, tokens: &mut TokenStream) {
850 let yield_type = &self.yield_type;
851 let return_type = &self.return_type;
852
853 let args: Punctuated<GenericArgument, Token![,]> = parse_quote! {
854 Yield = #yield_type,
855 Return = #return_type
856 };
857
858 args.to_tokens(tokens);
859 }
860}
861
862struct KarutinParameters {
863 pub paren_token: Paren,
864 pub inputs: Punctuated<FnArg, Comma>,
865}
866
867impl KarutinParameters {
868 pub fn into_pat_type(self) -> PatType {
869 let inputs_iter = self.inputs.into_iter();
870
871 let fm_closure = |arg| match arg {
872 FnArg::Typed(pt) => Some((pt.pat, pt.ty)),
873 _ => None,
874 };
875
876 let pairs: (Vec<Box<Pat>>, Vec<Box<Type>>) = inputs_iter.filter_map(fm_closure).unzip();
877
878 let (pats, types) = pairs;
879
880 parse_quote! {
881 ( #( #pats ),* ): ( #( #types ),* )
882 }
883 }
884}
885
886impl Parse for KarutinParameters {
887 fn parse(input: ParseStream) -> syn::Result<Self> {
888 let content;
889
890 Ok(Self {
891 paren_token: parenthesized!(content in input),
892 inputs: content.parse_terminated(FnArg::parse, Token![,])?,
893 })
894 }
895}
896
897impl ToTokens for KarutinParameters {
898 fn to_tokens(&self, tokens: &mut TokenStream) {
899 self.paren_token.surround(tokens, |tokens| {
900 self.inputs.to_tokens(tokens);
901 });
902 }
903}
904
905struct KarutinSignature {
906 pub unsafety: Option<Unsafe>,
907 pub fn_token: Token![fn],
908 pub ident: Ident,
909 pub generics: Generics,
910 pub parameters: KarutinParameters,
911 pub output: KarutinReturnType,
912}
913
914impl Parse for KarutinSignature {
915 fn parse(input: ParseStream) -> syn::Result<Self> {
916 Ok(Self {
917 unsafety: input.parse()?,
918 fn_token: input.parse()?,
919 ident: input.parse()?,
920 generics: input.parse()?,
921 parameters: input.parse()?,
922 output: input.parse()?,
923 })
924 }
925}
926
927impl ToTokens for KarutinSignature {
928 fn to_tokens(&self, tokens: &mut TokenStream) {
929 self.unsafety.to_tokens(tokens);
930 self.fn_token.to_tokens(tokens);
931 self.ident.to_tokens(tokens);
932 self.generics.to_tokens(tokens);
933 self.parameters.to_tokens(tokens);
934
935 let type_stream = &mut TokenStream::new();
936
937 self.parameters.to_tokens(type_stream);
938 Comma::default().to_tokens(type_stream);
939 self.output.to_tokens(type_stream);
940
941 let type_: Type = parse_quote! {
942 impl ::karutin::Karutin<#type_stream>
943 };
944
945 RArrow::default().to_tokens(tokens);
946 type_.to_tokens(tokens);
947 }
948}
949
950struct KarutinFn {
951 pub vis: Visibility,
952 pub sig: KarutinSignature,
953 pub block: Box<Block>,
954}
955
956impl Parse for KarutinFn {
957 fn parse(input: ParseStream) -> syn::Result<Self> {
958 Ok(Self {
959 vis: input.parse()?,
960 sig: input.parse()?,
961 block: input.parse()?,
962 })
963 }
964}
965
966impl ToTokens for KarutinFn {
967 fn to_tokens(&self, tokens: &mut TokenStream) {
968 self.vis.to_tokens(tokens);
969 self.sig.to_tokens(tokens);
970 self.block.to_tokens(tokens);
971 }
972}
973
974struct KarutinFnList {
975 inner: Vec<KarutinFn>,
976}
977
978impl KarutinFnList {
979 fn into_inner(self) -> Vec<KarutinFn> {
980 self.inner
981 }
982}
983
984impl Parse for KarutinFnList {
985 fn parse(input: ParseStream) -> syn::Result<Self> {
986 let mut inner = Vec::new();
987
988 while !input.is_empty() {
989 inner.push(input.parse()?);
990 }
991
992 Ok(Self { inner })
993 }
994}
995
996impl ToTokens for KarutinFnList {
997 fn to_tokens(&self, tokens: &mut TokenStream) {
998 for v in &self.inner {
999 v.to_tokens(tokens);
1000 }
1001 }
1002}
1003
1004struct KarutinSigList {
1005 inner: Vec<KarutinSignature>,
1006}
1007
1008impl KarutinSigList {
1009 fn into_inner(self) -> Vec<KarutinSignature> {
1010 self.inner
1011 }
1012}
1013
1014impl Parse for KarutinSigList {
1015 fn parse(input: ParseStream) -> syn::Result<Self> {
1016 let mut inner = Vec::new();
1017
1018 while !input.is_empty() {
1019 inner.push(input.parse()?);
1020 let _: Semi = input.parse()?;
1021 }
1022
1023 Ok(Self { inner })
1024 }
1025}
1026
1027impl ToTokens for KarutinSigList {
1028 fn to_tokens(&self, tokens: &mut TokenStream) {
1029 for v in &self.inner {
1030 v.to_tokens(tokens);
1031 Semi::default().to_tokens(tokens);
1032 }
1033 }
1034}
1035
1036enum Karutin {
1037 DefinitionList(KarutinFnList),
1038 DeclarationList(KarutinSigList),
1039}
1040
1041impl Parse for Karutin {
1042 fn parse(input: ParseStream) -> syn::Result<Self> {
1043 let mut errors: Option<Error> = None;
1044
1045 let mut combine = |e: Error| {
1046 let errors = &mut errors;
1047
1048 if let Some(errors) = errors {
1049 errors.combine(e);
1050 } else {
1051 let _ = errors.insert(e);
1052 }
1053 };
1054
1055 let fork = &input.fork();
1056 match KarutinFnList::parse(fork) {
1057 Ok(v) => {
1058 input.advance_to(fork);
1059 return Ok(Self::DefinitionList(v));
1060 },
1061 Err(e) => combine(e),
1062 }
1063
1064 let fork = &input.fork();
1065 match KarutinSigList::parse(fork) {
1066 Ok(v) => {
1067 input.advance_to(fork);
1068 return Ok(Self::DeclarationList(v));
1069 },
1070 Err(e) => combine(e),
1071 }
1072
1073 Err(errors.unwrap())
1074 }
1075}
1076
1077impl ToTokens for Karutin {
1078 fn to_tokens(&self, tokens: &mut TokenStream) {
1079 match self {
1080 Karutin::DefinitionList(karutin_fn_list) => karutin_fn_list.to_tokens(tokens),
1081 Karutin::DeclarationList(karutin_sig_list) => karutin_sig_list.to_tokens(tokens),
1082 }
1083 }
1084}
1085
1086fn wrap_completed_state(body: Box<Block>) -> TokenStream {
1087 parse_quote! {
1088 if self.states[#COMPLETED_STATE_ID] == 0 {
1089 #[allow(unreachable_code)]
1090 let _state = ::karutin::KarutinState::Returned( #body );
1091
1092 self.states[#COMPLETED_STATE_ID] = 1;
1093 _state
1094 } else {
1095 ::karutin::KarutinState::Completed
1096 }
1097 }
1098}
1099
1100fn zeroed_stack_locals(local_count: usize) -> TokenStream {
1101 let fields = create_stack_field_idents(local_count);
1102
1103 quote! {
1104 #(
1105 stack.#fields = ::karutin::internal::unchecked_zeroed();
1106 )*
1107 }
1108}
1109
1110fn handle_moved_stack(local_count: usize) -> TokenStream {
1111 let zsl = zeroed_stack_locals(local_count);
1112 let swap = quote! {
1113 unsafe {
1114 ::std::mem::swap(
1115 &mut *raw_stack_ptr,
1116 &mut stack_rep,
1117 );
1118 }
1119 };
1120
1121 quote! { #swap #zsl #swap }
1122}
1123
1124fn obtain_default_lifetime(ty: &mut Type) {
1125 match ty {
1126 Type::Array(type_array) => obtain_default_lifetime(type_array.elem.as_mut()),
1127 Type::Group(type_group) => obtain_default_lifetime(type_group.elem.as_mut()),
1128 Type::Paren(type_paren) => obtain_default_lifetime(type_paren.elem.as_mut()),
1129 Type::Slice(type_slice) => obtain_default_lifetime(type_slice.elem.as_mut()),
1130 Type::Path(_type_path) => {
1131 },
1133 Type::Tuple(type_tuple) => {
1134 for ty in type_tuple.elems.iter_mut() {
1135 obtain_default_lifetime(ty);
1136 }
1137 },
1138 Type::Reference(type_reference) if type_reference.lifetime.is_none() => {
1139 let lifetime = Lifetime::new(LIFETIME_STR, Span::call_site());
1140 type_reference.lifetime = Some(lifetime);
1141 },
1142 _ => {},
1143 }
1144}
1145
1146fn wrap_stack_management(
1147 stack_ident: &Ident,
1148 empty_generics: &TokenStream,
1149 local_count: usize,
1150 body: TokenStream,
1151) -> TokenStream {
1152 let hms = handle_moved_stack(local_count);
1153 let state_loop_label = Lifetime::new(STATE_LOOP_LABEL_STR, Span::call_site());
1154
1155 quote! {
1156 let mut stack;
1157 let mut stack_rep;
1158
1159 if let Some(stack_) = self.stack.as_ref() {
1160 (stack, stack_rep) = stack_.get_boxes::<#stack_ident<#empty_generics>>();
1161 } else {
1162 (stack, stack_rep) = ::karutin::internal::KarutinStack::create_zeroeds();
1163 let ret = ::karutin::internal::KarutinStack::from((stack, stack_rep));
1164 return ::karutin::internal::KarutinResponse::StackExpose(ret);
1165 }
1166
1167 let raw_stack_ptr = &mut stack as *mut Box<_>;
1168
1169 let ret = #state_loop_label: loop {
1170 break #body
1171 };
1172
1173 #hms
1174
1175 ::std::mem::forget(stack);
1176 ::std::mem::forget(stack_rep);
1177
1178 ::karutin::internal::KarutinResponse::StateLoop(ret)
1179 }
1180}
1181
1182fn check_karutin_fn(karutin_fn: &KarutinFn) -> Option<Error> {
1183 let checks = [
1184 check_blocks_macro_usage(karutin_fn),
1185 check_restriction_errors(karutin_fn),
1186 ];
1187
1188 let mut checks_result = Option::<Error>::None;
1189
1190 for check in checks.into_iter() {
1191 match (&mut checks_result, check) {
1192 (None, Some(err)) => {
1193 checks_result = Some(err);
1194 },
1195 (Some(base_err), Some(err)) => {
1196 base_err.combine(err);
1197 },
1198 _ => {},
1199 }
1200 }
1201
1202 checks_result
1203}
1204
1205fn karutin_stack(ident: &Ident, generics: &TokenStream, fields: &TokenStream) -> TokenStream {
1206 quote! {
1207 #[allow(non_camel_case_types)]
1208 struct #ident<#generics> {
1209 #fields
1210 }
1211 }
1212}
1213
1214fn karutin_ctx(ident: &Ident, lifetime: &Lifetime, state_count: usize) -> TokenStream {
1215 quote! {
1216 #[allow(non_camel_case_types)]
1217 #[derive(Default)]
1218 struct #ident<#lifetime> {
1219 stack: Option<::karutin::internal::KarutinStack<#lifetime>>,
1220 states: [usize; #state_count]
1221 }
1222 }
1223}
1224
1225fn karutin_resume_inner(
1226 ctx_ident: &Ident,
1227 lifetime: &Lifetime,
1228 generics: &Generics,
1229 params_pat: &TokenStream,
1230 params_ty: &Type,
1231 yield_type: &Type,
1232 return_type: &Type,
1233 body: &TokenStream,
1234) -> TokenStream {
1235 quote! {
1236 #[allow(unused_braces)]
1237 impl<#lifetime> #ctx_ident<#lifetime> {
1238 fn resume_inner #generics (&mut self, #params_pat: #params_ty)
1239 -> ::karutin::internal::KarutinResponse<#lifetime, #yield_type, #return_type>
1240 {
1241 #body
1242 }
1243 }
1244 }
1245}
1246
1247fn karutin_impl_debug(ctx_ident: &Ident, lifetime: &Lifetime, debug_name: &String) -> TokenStream {
1248 quote! {
1249 impl<#lifetime> std::fmt::Debug for #ctx_ident<#lifetime> {
1250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1251 f.debug_struct(#debug_name)
1252 .field("stack", self.stack.as_ref().unwrap())
1253 .field("states", &self.states)
1254 .finish()
1255 }
1256 }
1257 }
1258}
1259
1260fn karutin_impl_karutin(
1261 ctx_ident: &Ident,
1262 lifetime: &Lifetime,
1263 generics: &Generics,
1264 parameters_ty: &Type,
1265 yield_type: &Type,
1266 return_type: &Type,
1267) -> TokenStream {
1268 quote! {
1269 impl #generics ::karutin::Karutin<#parameters_ty> for #ctx_ident<#lifetime> {
1270 type Yield = #yield_type;
1271 type Return = #return_type;
1272
1273 #[inline(always)]
1274 fn resume(
1275 &mut self,
1276 args: #parameters_ty
1277 ) -> ::karutin::KarutinState<#yield_type, #return_type> {
1278 match self.resume_inner(args) {
1279 ::karutin::internal::KarutinResponse::StateLoop(v) => v,
1280 _ => { unreachable!() },
1281 }
1282 }
1283 }
1284 }
1285}
1286
1287fn karutin_signature(
1288 ident: &Ident,
1289 unsafety: &Option<Unsafe>,
1290 generics: &Generics,
1291 parameters_ty: &Type,
1292 yield_type: &Type,
1293 return_type: &Type,
1294) -> TokenStream {
1295 quote! {
1296 #unsafety fn #ident #generics () -> impl ::karutin::Karutin<
1297 #parameters_ty,
1298 Yield = #yield_type,
1299 Return = #return_type
1300 >
1301 }
1302}
1303
1304fn _karutin_fn(
1305 ctx_ident: &Ident,
1306 ident: &Ident,
1307 vis: &Visibility,
1308 unsafety: &Option<Unsafe>,
1309 generics: &Generics,
1310 parameters_ty: &Type,
1311 yield_type: &Type,
1312 return_type: &Type,
1313) -> TokenStream {
1314 let signature = karutin_signature(
1315 ident,
1316 unsafety,
1317 generics,
1318 parameters_ty,
1319 yield_type,
1320 return_type,
1321 );
1322
1323 quote! {
1324 #[inline]
1325 #vis #signature {
1326 let mut ctx = #ctx_ident::default();
1327
1328 let cold_start = ctx.resume_inner(
1329 ::karutin::internal::unchecked_zeroed()
1330 );
1331
1332 match cold_start {
1333 ::karutin::internal::KarutinResponse::StackExpose(v) => {
1334 ctx.stack = Some(v);
1335 },
1336 _ => { unreachable!() },
1337 }
1338
1339 ctx
1340 }
1341 }
1342}
1343
1344fn karutin_definition(mut karutin_fn: KarutinFn) -> TokenStream {
1345 if let Some(failed_check) = check_karutin_fn(&karutin_fn) {
1346 return failed_check.into_compile_error();
1347 }
1348
1349 transpile(&mut karutin_fn.block);
1350
1351 let local_count = build_stack(&mut karutin_fn.block);
1352 let state_count = sift_states(&mut karutin_fn.block);
1353
1354 let vis = karutin_fn.vis;
1355 let unsafety = karutin_fn.sig.unsafety;
1356 let ident = karutin_fn.sig.ident;
1357
1358 let ctx_ident = format_context_ident!(ident);
1359 let lifetime = Lifetime::new(LIFETIME_STR, Span::call_site());
1360
1361 let generics = karutin_fn.sig.generics;
1362 let mut combined_generics = generics.clone();
1363 let mut inner_generics = generics.clone();
1364
1365 let lifetime_param = LifetimeParam::new(lifetime.clone());
1366 let generic_param = GenericParam::Lifetime(lifetime_param);
1367
1368 combined_generics.params.insert(0, generic_param);
1369
1370 for generic_param in &mut inner_generics.params {
1371 match generic_param {
1372 syn::GenericParam::Lifetime(lifetime_param) => {
1373 lifetime_param.bounds.push(lifetime.clone());
1374 },
1375 _ => {},
1376 }
1377 }
1378
1379 let parameters = karutin_fn.sig.parameters.into_pat_type();
1380 let parameters_pat = parameters.pat.to_token_stream();
1381
1382 let mut parameters_ty = parameters.ty;
1383 let mut yield_type = karutin_fn.sig.output.yield_type;
1384 let mut return_type = karutin_fn.sig.output.return_type;
1385
1386 obtain_default_lifetime(parameters_ty.as_mut());
1387 obtain_default_lifetime(yield_type.as_mut());
1388 obtain_default_lifetime(return_type.as_mut());
1389
1390 let body = wrap_completed_state(karutin_fn.block).to_token_stream();
1391
1392 let stack_ident = format_stack_ident!(ident);
1393 let stack_generics = create_stack_generics(local_count);
1394 let empty_stack_generics = create_empty_stack_generics(local_count);
1395 let stack_fields = create_stack_fields(local_count);
1396
1397 let body2 = wrap_stack_management(&stack_ident, &empty_stack_generics, local_count, body);
1398
1399 let debug_name = format!("Karutin Context ({})", ident);
1400
1401 let stack_quote = karutin_stack(&stack_ident, &stack_generics, &stack_fields);
1402 let ctx_quote = karutin_ctx(&ctx_ident, &lifetime, state_count);
1403 let resume_inner_quote = karutin_resume_inner(
1404 &ctx_ident,
1405 &lifetime,
1406 &inner_generics,
1407 ¶meters_pat,
1408 ¶meters_ty,
1409 &yield_type,
1410 &return_type,
1411 &body2,
1412 );
1413 let impl_debug_quote = karutin_impl_debug(&ctx_ident, &lifetime, &debug_name);
1414 let impl_karutin = karutin_impl_karutin(
1415 &ctx_ident,
1416 &lifetime,
1417 &combined_generics,
1418 ¶meters_ty,
1419 &yield_type,
1420 &return_type,
1421 );
1422 let _fn = _karutin_fn(
1423 &ctx_ident,
1424 &ident,
1425 &vis,
1426 &unsafety,
1427 &combined_generics,
1428 ¶meters_ty,
1429 &yield_type,
1430 &return_type,
1431 );
1432
1433 quote! {
1434 #stack_quote
1435 #ctx_quote
1436 #resume_inner_quote
1437 #impl_debug_quote
1438 #impl_karutin
1439 #_fn
1440 }
1441}
1442
1443fn karutin_declaration(karutin_sig: KarutinSignature) -> TokenStream {
1444 let unsafety = karutin_sig.unsafety;
1445 let ident = karutin_sig.ident;
1446
1447 let mut combined_generics = karutin_sig.generics;
1448
1449 let lifetime = Lifetime::new(LIFETIME_STR, Span::call_site());
1450 let lifetime_param = LifetimeParam::new(lifetime.clone());
1451 let generic_param = GenericParam::Lifetime(lifetime_param);
1452 let parameters = karutin_sig.parameters.into_pat_type();
1453
1454 let mut parameters_ty = parameters.ty;
1455 let mut yield_type = karutin_sig.output.yield_type;
1456 let mut return_type = karutin_sig.output.return_type;
1457
1458 combined_generics.params.insert(0, generic_param);
1459
1460 obtain_default_lifetime(parameters_ty.as_mut());
1461 obtain_default_lifetime(yield_type.as_mut());
1462 obtain_default_lifetime(return_type.as_mut());
1463
1464 let signature = karutin_signature(
1465 &ident,
1466 &unsafety,
1467 &combined_generics,
1468 ¶meters_ty,
1469 &yield_type,
1470 &return_type,
1471 );
1472
1473 quote! {
1474 #signature;
1475 }
1476}
1477
1478type KarutinDslInput<'a> = &'a mut Peekable<proc_macro::token_stream::IntoIter>;
1479
1480fn karutin_dsl_yield_from(input: KarutinDslInput, output: &mut proc_macro::TokenStream) {
1481 let yield_ident = match input.next() {
1482 Some(proc_macro::TokenTree::Ident(i)) => i,
1483 _ => unreachable!(),
1484 };
1485
1486 let attr = Transpiler::create_yield_from_attr();
1487 let attr_stream2: TokenStream = quote_spanned! {yield_ident.span().into()=>#attr};
1488 let attr_stream: proc_macro::TokenStream = attr_stream2.into_token_stream().into();
1489
1490 output.extend(attr_stream);
1491 output.extend([proc_macro::TokenTree::Ident(yield_ident)]);
1492}
1493
1494fn karutin_dsl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
1495 let mut output = proc_macro::TokenStream::new();
1496 let mut input = input.into_iter().peekable();
1497
1498 while let Some(tt) = input.next() {
1499 match tt {
1500 proc_macro::TokenTree::Punct(p) if p.as_char() == '~' => {
1501 if let Some(proc_macro::TokenTree::Ident(ident)) = input.peek() {
1502 if ident.to_string() == "yield" {
1503 karutin_dsl_yield_from(&mut input, &mut output);
1504 continue;
1505 }
1506 }
1507
1508 output.extend([p]);
1509 },
1510 proc_macro::TokenTree::Group(g) => {
1511 let del = g.delimiter();
1512 let stream = karutin_dsl(g.stream());
1513 let group = proc_macro::Group::new(del, stream);
1514
1515 output.extend([group]);
1516 },
1517 other => output.extend([other]),
1518 }
1519 }
1520
1521 output
1522}
1523
1524#[proc_macro]
1703pub fn karutin(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
1704 let input = karutin_dsl(input);
1705 let parsed = parse_macro_input!(input as Karutin);
1706
1707 let mut stream = TokenStream::new();
1708
1709 match parsed {
1710 Karutin::DefinitionList(karutin_fn_list) => {
1711 for karutin_fn in karutin_fn_list.into_inner() {
1712 stream.extend(karutin_definition(karutin_fn));
1713 }
1714 },
1715 Karutin::DeclarationList(karutin_sig_list) => {
1716 for karutin_sig in karutin_sig_list.into_inner() {
1717 stream.extend(karutin_declaration(karutin_sig));
1718 }
1719 },
1720 }
1721
1722 stream.into_token_stream().into()
1723}
1724
1725#[proc_macro]
1736pub fn karutin_str(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
1737 let stream = karutin(input);
1738
1739 let syntax_tree: syn::File = parse2(stream.into()).unwrap();
1740 let formatted = prettyplease::unparse(&syntax_tree);
1741
1742 let str: Expr = parse_quote! { #formatted };
1743
1744 str.into_token_stream().into()
1745}