use crate::generator_state::GeneratorStateInfo;
use crate::generator_yield_analysis::YieldAnalysis;
use crate::hir::{HirExpr, HirFunction, HirStmt, Type};
use crate::rust_gen::context::{CodeGenContext, ToRustExpr};
use crate::rust_gen::type_gen::rust_type_to_syn;
use anyhow::Result;
use quote::quote;
/// Builds one `name: Type` struct-field declaration per generator state variable.
///
/// Each variable's HIR type is run through `ctx.type_mapper` and the result is
/// converted to a `syn` type with `rust_type_to_syn`.
///
/// # Errors
///
/// Returns the first error reported by `rust_type_to_syn`; variables after the
/// failing one are not processed.
fn generate_state_fields(
    state_info: &GeneratorStateInfo,
    ctx: &mut CodeGenContext,
) -> Result<Vec<proc_macro2::TokenStream>> {
    let mut declarations = Vec::new();
    for state_var in state_info.state_variables.iter() {
        let ident = syn::Ident::new(&state_var.name, proc_macro2::Span::call_site());
        let mapped = ctx.type_mapper.map_type(&state_var.ty);
        let ty = rust_type_to_syn(&mapped)?;
        declarations.push(quote! { #ident: #ty });
    }
    Ok(declarations)
}
/// Builds one `name: Type` struct-field declaration for every function
/// parameter whose name appears in `state_info.captured_params`.
///
/// Parameters are visited in declaration order (`func.params`); parameters
/// that are not captured are skipped. Each captured parameter's HIR type is
/// run through `ctx.type_mapper` and converted with `rust_type_to_syn`.
///
/// # Errors
///
/// Returns the first error reported by `rust_type_to_syn`.
fn generate_param_fields(
    func: &HirFunction,
    state_info: &GeneratorStateInfo,
    ctx: &mut CodeGenContext,
) -> Result<Vec<proc_macro2::TokenStream>> {
    func.params
        .iter()
        .filter(|p| state_info.captured_params.contains(&p.name))
        .map(|param| {
            let field_name = syn::Ident::new(&param.name, proc_macro2::Span::call_site());
            let rust_type = ctx.type_mapper.map_type(&param.ty);
            let field_type = rust_type_to_syn(&rust_type)?;
            Ok(quote! { #field_name: #field_type })
        })
        .collect()
}
/// Maps the function's declared return type (`func.ret_type`) to a `syn::Type`.
///
/// The caller uses the result as the `Item` type of the generated iterator.
///
/// # Errors
///
/// Propagates any error from `rust_type_to_syn`.
#[inline]
fn extract_generator_item_type(func: &HirFunction, ctx: &CodeGenContext) -> Result<syn::Type> {
    let mapped_ret_type = ctx.type_mapper.map_type(&func.ret_type);
    let item_type = rust_type_to_syn(&mapped_ret_type);
    item_type
}
/// Builds one `name: <default>` struct-literal initializer per generator
/// state variable, where the default expression is chosen by
/// `get_default_value_for_type` from the variable's HIR type.
fn generate_state_initializers(state_info: &GeneratorStateInfo) -> Vec<proc_macro2::TokenStream> {
    let mut initializers = Vec::new();
    for state_var in state_info.state_variables.iter() {
        let ident = syn::Ident::new(&state_var.name, proc_macro2::Span::call_site());
        let initial = get_default_value_for_type(&state_var.ty);
        initializers.push(quote! { #ident: #initial });
    }
    initializers
}
/// Builds one `name: name` struct-literal initializer for every function
/// parameter whose name appears in `state_info.captured_params`.
///
/// Each captured parameter is stored in a field of the same name, so the
/// initializer simply forwards the parameter binding. Parameters are visited
/// in declaration order (`func.params`).
fn generate_param_initializers(
    func: &HirFunction,
    state_info: &GeneratorStateInfo,
) -> Vec<proc_macro2::TokenStream> {
    func.params
        .iter()
        .filter(|p| state_info.captured_params.contains(&p.name))
        .map(|param| {
            let field_name = syn::Ident::new(&param.name, proc_macro2::Span::call_site());
            quote! { #field_name: #field_name }
        })
        .collect()
}
/// Initial-value tokens for an integer state field: the literal `0`.
#[inline]
fn default_int() -> proc_macro2::TokenStream {
    quote!(0)
}
/// Initial-value tokens for a float state field: the literal `0.0`.
#[inline]
fn default_float() -> proc_macro2::TokenStream {
    quote!(0.0)
}
/// Initial-value tokens for a boolean state field: `false`.
#[inline]
fn default_bool() -> proc_macro2::TokenStream {
    quote!(false)
}
/// Initial-value tokens for a string state field: `String::new()`.
#[inline]
fn default_string() -> proc_macro2::TokenStream {
    quote!(String::new())
}
/// Fallback initial-value tokens for any other state field type:
/// `Default::default()`.
#[inline]
fn default_generic() -> proc_macro2::TokenStream {
    quote!(Default::default())
}
#[inline]
fn get_default_value_for_type(ty: &Type) -> proc_macro2::TokenStream {
match ty {
Type::Int => default_int(),
Type::Float => default_float(),
Type::Bool => default_bool(),
Type::String => default_string(),
_ => default_generic(),
}
}
/// Derives the state-struct identifier for a generator function.
///
/// The function name is split on `_`, the first character of every non-empty
/// segment is upper-cased, the segments are concatenated and `State` is
/// appended (e.g. `count_up` becomes `CountUpState`). The returned identifier
/// reuses the span of `name`.
#[inline]
fn generate_state_struct_name(name: &syn::Ident) -> syn::Ident {
    let snake = name.to_string();
    let mut struct_name = String::with_capacity(snake.len() + "State".len());
    for segment in snake.split('_') {
        let mut rest = segment.chars();
        // Empty segments (from leading, trailing or doubled underscores)
        // contribute nothing.
        if let Some(initial) = rest.next() {
            struct_name.extend(initial.to_uppercase());
            struct_name.push_str(rest.as_str());
        }
    }
    struct_name.push_str("State");
    syn::Ident::new(&struct_name, name.span())
}
/// Replaces `ctx.generator_state_vars` with the names of all state variables
/// followed by the names of all captured parameters from `state_info`.
///
/// Any names previously recorded in the context are discarded first.
#[inline]
fn populate_generator_state_vars(ctx: &mut CodeGenContext, state_info: &GeneratorStateInfo) {
    ctx.generator_state_vars.clear();
    let tracked_names = state_info
        .state_variables
        .iter()
        .map(|state_var| &state_var.name)
        .chain(state_info.captured_params.iter());
    for tracked in tracked_names {
        ctx.generator_state_vars.insert(tracked.clone());
    }
}
/// Converts every statement of the function body to Rust tokens while
/// `ctx.in_generator` is set.
///
/// On return, whether successful or not, `ctx.in_generator` is reset to
/// `false` and `ctx.generator_state_vars` is cleared, so a failed conversion
/// cannot leave generator-only state behind in the shared context.
///
/// # Errors
///
/// Returns the first error produced by converting a body statement;
/// statements after the failing one are not converted.
#[inline]
fn generate_generator_body(
    func: &HirFunction,
    ctx: &mut CodeGenContext,
) -> Result<Vec<proc_macro2::TokenStream>> {
    use crate::rust_gen::RustCodeGen;
    ctx.in_generator = true;
    // Keep the outcome as a `Result` here instead of using `?`, so the cleanup
    // below also runs on the error path.
    let generator_body_stmts: Result<Vec<_>> = func
        .body
        .iter()
        .map(|stmt| stmt.to_rust_tokens(ctx))
        .collect();
    ctx.in_generator = false;
    ctx.generator_state_vars.clear();
    generator_body_stmts
}
#[inline]
fn hir_expr_to_syn(expr: &HirExpr, ctx: &mut CodeGenContext) -> Result<syn::Expr> {
expr.to_rust_expr(ctx)
}
/// Generates the `match self.state { .. }` body for a generator whose yields
/// are handled as a simple sequence.
///
/// The emitted arms are, in order:
/// * `0`: moves to the first yield point's `state_id` and returns its value;
/// * for every yield point that has a successor: an arm keyed by its
///   `state_id` that moves to the successor's `state_id` and returns the
///   successor's value;
/// * for the last yield point: an arm keyed by its `state_id` that sets
///   `self.state` to `state_id + 1` and evaluates to `None`;
/// * `_`: `None`.
///
/// NOTE(review): arm `0` is emitted in addition to the per-yield arms, so this
/// presumably relies on yield `state_id`s never being `0` - confirm against
/// `YieldAnalysis`.
///
/// # Errors
///
/// Propagates the first failure from converting a yield expression.
#[inline]
fn generate_simple_multi_state_match(
    yield_analysis: &YieldAnalysis,
    _func: &HirFunction,
    ctx: &mut CodeGenContext,
) -> Result<proc_macro2::TokenStream> {
    let points = &yield_analysis.yield_points;
    let mut arms = Vec::new();

    // Entry arm: the initial state produces the first yielded value.
    if let Some(entry) = points.first() {
        let value = hir_expr_to_syn(&entry.yield_expr, ctx)?;
        let target = entry.state_id;
        arms.push(quote! {
            0 => {
                self.state = #target;
                return Some(#value);
            }
        });
    }

    // Transition arms: each yield point hands over to the one after it.
    for pair in points.windows(2) {
        let (current, upcoming) = (&pair[0], &pair[1]);
        let from_state = current.state_id;
        let value = hir_expr_to_syn(&upcoming.yield_expr, ctx)?;
        let to_state = upcoming.state_id;
        arms.push(quote! {
            #from_state => {
                self.state = #to_state;
                return Some(#value);
            }
        });
    }

    // Terminal arm: after the last yield the iterator is exhausted.
    if let Some(last) = points.last() {
        let final_state = last.state_id;
        arms.push(quote! {
            #final_state => {
                self.state = #final_state + 1;
                None
            }
        });
    }

    arms.push(quote! {
        _ => None
    });

    Ok(quote! {
        match self.state {
            #(#arms)*
        }
    })
}
#[inline]
fn generate_simple_loop_with_yield(
func: &HirFunction,
yield_analysis: &YieldAnalysis,
ctx: &mut CodeGenContext,
) -> Result<proc_macro2::TokenStream> {
let loop_info = extract_loop_info(func)?;
let init_stmts = generate_loop_init_stmts(&loop_info.pre_loop_stmts, ctx)?;
let yield_point = &yield_analysis.yield_points[0];
let yield_value = hir_expr_to_syn(&yield_point.yield_expr, ctx)?;
let loop_condition = hir_expr_to_syn(&loop_info.condition, ctx)?;
let loop_body_stmts = generate_loop_body_stmts(&loop_info.body_stmts, ctx)?;
Ok(quote! {
match self.state {
0 => {
#(#init_stmts)*
self.state = 1;
self.next()
}
1 => {
if #loop_condition {
let result = #yield_value;
#(#loop_body_stmts)*
return Some(result);
} else {
self.state = 2;
None
}
}
_ => None
}
})
}
#[inline]
fn extract_loop_info(func: &HirFunction) -> Result<LoopInfo> {
let mut pre_loop_stmts = Vec::new();
let mut loop_stmt = None;
for stmt in &func.body {
match stmt {
HirStmt::While { condition, body } => {
loop_stmt = Some((condition.clone(), body.clone()));
break;
}
_ => {
pre_loop_stmts.push(stmt.clone());
}
}
}
let (condition, body) = loop_stmt
.ok_or_else(|| anyhow::anyhow!("No while loop found in generator function"))?;
let mut body_stmts = Vec::new();
for stmt in &body {
if !matches!(stmt, HirStmt::Expr(HirExpr::Yield { .. })) {
body_stmts.push(stmt.clone());
}
}
Ok(LoopInfo {
pre_loop_stmts,
condition,
body_stmts,
})
}
/// Pieces of a generator function body split around its first top-level
/// `while` loop, as produced by `extract_loop_info`.
struct LoopInfo {
    /// Statements that appear before the `while` loop.
    pre_loop_stmts: Vec<HirStmt>,
    /// The loop's condition expression.
    condition: HirExpr,
    /// Loop body statements, excluding top-level yield expression statements.
    body_stmts: Vec<HirStmt>,
}
/// Converts the statements that precede the loop to Rust tokens, in order.
///
/// # Errors
///
/// Stops at, and returns, the first conversion error.
#[inline]
fn generate_loop_init_stmts(
    stmts: &[HirStmt],
    ctx: &mut CodeGenContext,
) -> Result<Vec<proc_macro2::TokenStream>> {
    use crate::rust_gen::RustCodeGen;
    let mut init_tokens = Vec::with_capacity(stmts.len());
    for init_stmt in stmts {
        init_tokens.push(init_stmt.to_rust_tokens(ctx)?);
    }
    Ok(init_tokens)
}
/// Converts the loop body statements to Rust tokens, in order.
///
/// # Errors
///
/// Stops at, and returns, the first conversion error.
#[inline]
fn generate_loop_body_stmts(
    stmts: &[HirStmt],
    ctx: &mut CodeGenContext,
) -> Result<Vec<proc_macro2::TokenStream>> {
    use crate::rust_gen::RustCodeGen;
    let mut body_tokens = Vec::with_capacity(stmts.len());
    for body_stmt in stmts {
        let tokens = body_stmt.to_rust_tokens(ctx)?;
        body_tokens.push(tokens);
    }
    Ok(body_tokens)
}
/// Generates the Rust items for a generator function.
///
/// The output consists of three items:
/// * a `#[derive(Debug)]` state struct (named by `generate_state_struct_name`)
///   holding a `state: usize` counter, the generator's state variables and
///   its captured parameters;
/// * a `pub fn` with the original name, generics, parameters and where-clause
///   that returns `impl Iterator<Item = ..>` and constructs the state struct
///   with `state: 0`;
/// * an `Iterator` impl for the state struct whose `next` body is one of
///   three state machines, chosen as follows:
///   1. every yield point has `depth == 0`: `generate_simple_multi_state_match`;
///   2. the body has a top-level `while`, some yield point has `depth > 0`
///      and there is exactly one yield point: `generate_simple_loop_with_yield`;
///   3. otherwise: the whole body is emitted once in state `0`, followed by `None`.
///
/// `attrs` are placed in front of the state struct. `_rust_ret_type` is unused.
///
/// # Errors
///
/// Propagates any type-mapping or statement/expression conversion error.
#[inline]
#[allow(clippy::too_many_arguments)]
pub fn codegen_generator_function(
    func: &HirFunction,
    name: &syn::Ident,
    generic_params: &proc_macro2::TokenStream,
    where_clause: &proc_macro2::TokenStream,
    params: &[proc_macro2::TokenStream],
    attrs: &[proc_macro2::TokenStream],
    _rust_ret_type: &crate::type_mapper::RustType,
    ctx: &mut CodeGenContext,
) -> Result<proc_macro2::TokenStream> {
    let state_info = GeneratorStateInfo::analyze(func);
    let yield_analysis = YieldAnalysis::analyze(func);

    // Struct name, field declarations and field initializers; state variables
    // come first, captured parameters second.
    let struct_ident = generate_state_struct_name(name);
    let mut struct_fields = generate_state_fields(&state_info, ctx)?;
    struct_fields.extend(generate_param_fields(func, &state_info, ctx)?);
    let mut field_inits = generate_state_initializers(&state_info);
    field_inits.extend(generate_param_initializers(func, &state_info));

    let item_type = extract_generator_item_type(func, ctx)?;

    // Must run before any body/expression conversion below.
    // NOTE(review): only `generate_generator_body` sets `ctx.in_generator` and
    // clears these names afterwards; the two specialised paths do neither -
    // confirm that is intended.
    populate_generator_state_vars(ctx, &state_info);

    // Classify the yields to pick a state-machine shape.
    let points = &yield_analysis.yield_points;
    let yields_present = yield_analysis.has_yields();
    let all_yields_top_level = yields_present && points.iter().all(|yp| yp.depth == 0);
    let some_yield_nested = yields_present && points.iter().any(|yp| yp.depth > 0);
    let has_while_loop = func
        .body
        .iter()
        .any(|stmt| matches!(stmt, HirStmt::While { .. }));

    let next_body = if all_yields_top_level {
        generate_simple_multi_state_match(&yield_analysis, func, ctx)?
    } else if has_while_loop && some_yield_nested && points.len() == 1 {
        generate_simple_loop_with_yield(func, &yield_analysis, ctx)?
    } else {
        let body_stmts = generate_generator_body(func, ctx)?;
        quote! {
            match self.state {
                0 => {
                    self.state = 1;
                    #(#body_stmts)*
                    None
                }
                _ => None
            }
        }
    };

    Ok(quote! {
        #(#attrs)*
        #[doc = " Generator state struct"]
        #[derive(Debug)]
        struct #struct_ident {
            state: usize,
            #(#struct_fields),*
        }
        #[doc = " Generator function - returns Iterator"]
        pub fn #name #generic_params(#(#params),*) -> impl Iterator<Item = #item_type> #where_clause {
            #struct_ident {
                state: 0,
                #(#field_inits),*
            }
        }
        impl Iterator for #struct_ident {
            type Item = #item_type;
            fn next(&mut self) -> Option<Self::Item> {
                #next_body
            }
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `generate_state_struct_name` on `fn_name` and returns the
    /// resulting identifier as a string.
    fn state_struct_name_for(fn_name: &str) -> String {
        let ident = syn::Ident::new(fn_name, proc_macro2::Span::call_site());
        generate_state_struct_name(&ident).to_string()
    }

    /// A two-word snake_case name becomes PascalCase plus the `State` suffix.
    #[test]
    #[allow(non_snake_case)]
    fn test_DEPYLER_0259_snake_case_to_pascal_case_naming() {
        assert_eq!(
            state_struct_name_for("count_up"),
            "CountUpState",
            "DEPYLER-0259: Should convert snake_case to PascalCase, not just capitalize first char"
        );
    }

    /// A name without underscores only has its first character upper-cased.
    #[test]
    #[allow(non_snake_case)]
    fn test_DEPYLER_0259_single_word_naming() {
        assert_eq!(state_struct_name_for("counter"), "CounterState");
    }

    /// Every underscore-separated word is capitalised, not just the first.
    #[test]
    #[allow(non_snake_case)]
    fn test_DEPYLER_0259_multiple_words_naming() {
        assert_eq!(
            state_struct_name_for("fibonacci_generator_with_memo"),
            "FibonacciGeneratorWithMemoState"
        );
    }
}