use crate::hir::{HirExpr, Literal};
use crate::rust_gen::context::{CodeGenContext, ToRustExpr};
use anyhow::{bail, Result};
use syn::parse_quote;
/// Lowers a Python `range(...)` call, whose arguments are already Rust expressions, into a
/// Rust range expression.
///
/// * `range(end)` becomes `0..(end)`;
/// * `range(start, end)` becomes `(start)..(end)`;
/// * `range(start, end, step)` is forwarded to [`convert_range_with_step`].
///
/// # Errors
/// Returns an error when `args` is empty or holds more than three expressions.
pub fn convert_range_call(args: &[syn::Expr]) -> Result<syn::Expr> {
    match args {
        [end] => Ok(parse_quote! { 0..(#end) }),
        [start, end] => Ok(parse_quote! { (#start)..(#end) }),
        [start, end, step] => convert_range_with_step(start, end, step),
        _ => bail!("Invalid number of arguments for range()"),
    }
}
/// Lowers a three-argument `range(start, end, step)` call.
///
/// The direction is chosen purely syntactically. A step written as a unary negation (for
/// example `-1` or `-k`) is handled by [`convert_range_negative_step`]. Every other step
/// expression is handled by [`convert_range_positive_step`], including a variable whose value
/// happens to be negative at runtime.
pub fn convert_range_with_step(
    start: &syn::Expr,
    end: &syn::Expr,
    step: &syn::Expr,
) -> Result<syn::Expr> {
    match step {
        syn::Expr::Unary(syn::ExprUnary {
            op: syn::UnOp::Neg(_),
            ..
        }) => convert_range_negative_step(start, end, step),
        _ => convert_range_positive_step(start, end, step),
    }
}
/// Lowers `range(start, end, step)` where `step` is syntactically negative (e.g. `-1`).
///
/// Python yields `start, start - |step|, ...` while the value stays strictly greater than
/// `end`. That is the interval `(end, start]` walked downwards, i.e. `(end + 1)..(start + 1)`
/// reversed and stepped by `|step|`. So `range(10, 0, -1)` yields `10, 9, ..., 1`, and the
/// common idiom `range(n - 1, -1, -1)` yields `n - 1, ..., 0`.
///
/// The generated block evaluates `start`, `end` and `step` once, in Python's argument order,
/// and only then introduces its own locals. This way argument expressions that mention
/// variables named `start`, `end` or `step` are not captured by those locals. The step is
/// widened to `i64` before taking its magnitude, so `i32`/`i64`/`isize` steps are not
/// truncated.
///
/// The generated code panics with Python's message if the step evaluates to zero.
/// NOTE(review): `start + 1` overflows if `start` is the maximum value of its integer type.
pub fn convert_range_negative_step(
    start: &syn::Expr,
    end: &syn::Expr,
    step: &syn::Expr,
) -> Result<syn::Expr> {
    Ok(parse_quote! {
        {
            let (start, end, step) = (#start, #end, ((#step) as i64).unsigned_abs() as usize);
            if step == 0 {
                panic!("range() arg 3 must not be zero");
            }
            ((end + 1)..(start + 1)).rev().step_by(step)
        }
    })
}
/// Lowers `range(start, end, step)` where `step` is not syntactically negative.
///
/// The generated block evaluates `start`, `end` and `step` once, in Python's argument order,
/// and only then introduces its own locals. This way argument expressions that mention
/// variables named `start`, `end` or `step` are not captured by those locals (previously
/// `range(0, step, 2)` silently became `0..2`).
///
/// `step` is parenthesized before the `as usize` cast so that a compound step such as
/// `n * 2` is cast as a whole, rather than being parsed as `n * (2 as usize)`.
///
/// The generated code panics with Python's message if the step evaluates to zero.
/// NOTE(review): a step that is negative only at runtime (e.g. a variable) wraps when cast to
/// `usize`. Only a syntactic negation is routed to [`convert_range_negative_step`].
pub fn convert_range_positive_step(
    start: &syn::Expr,
    end: &syn::Expr,
    step: &syn::Expr,
) -> Result<syn::Expr> {
    Ok(parse_quote! {
        {
            let (start, end, step) = (#start, #end, (#step) as usize);
            if step == 0 {
                panic!("range() arg 3 must not be zero");
            }
            (start..end).step_by(step)
        }
    })
}
/// Lowers a NumPy-style array constructor (`zeros`, `ones` or `full`) into a `vec![...]`
/// expression.
///
/// The first HIR argument is the size:
/// * an integer literal in `1..=32` goes to [`convert_array_small_literal`];
/// * any other integer literal goes to [`convert_array_large_literal`];
/// * a non-literal size goes to [`convert_array_dynamic_size`].
///
/// NOTE(review): the reason for the 32 cutoff is not visible in this file.
/// `_arg_exprs` is accepted for signature compatibility and is not read.
///
/// # Errors
/// Returns an error if `args` is empty, or propagates the error from the selected helper.
pub fn convert_array_init_call(
    ctx: &mut CodeGenContext,
    func: &str,
    args: &[HirExpr],
    _arg_exprs: &[syn::Expr],
) -> Result<syn::Expr> {
    if args.is_empty() {
        bail!("{} requires at least one argument", func);
    }
    match &args[0] {
        HirExpr::Literal(Literal::Int(size)) if (1..=32).contains(size) => {
            convert_array_small_literal(ctx, func, args, *size)
        }
        HirExpr::Literal(Literal::Int(_)) => convert_array_large_literal(ctx, func, args),
        _ => convert_array_dynamic_size(ctx, func, args),
    }
}
/// Lowers `zeros`/`ones`/`full` when the size is an integer literal that the caller has
/// already checked is small.
///
/// `size` is re-emitted as an unsuffixed integer literal, producing `vec![0i32; N]`,
/// `vec![1i32; N]` or `vec![value; N]`.
///
/// # Errors
/// Fails if `full` is missing its value argument or that argument cannot be converted.
///
/// # Panics
/// Panics (`unreachable!`) when `func` is not `zeros`, `ones` or `full`.
pub fn convert_array_small_literal(
    ctx: &mut CodeGenContext,
    func: &str,
    args: &[HirExpr],
    size: i64,
) -> Result<syn::Expr> {
    let count = syn::LitInt::new(&size.to_string(), proc_macro2::Span::call_site());
    let lowered: syn::Expr = match func {
        "zeros" => parse_quote! { vec![0i32; #count] },
        "ones" => parse_quote! { vec![1i32; #count] },
        "full" => match args.get(1) {
            Some(fill) => {
                let value = fill.to_rust_expr(ctx)?;
                parse_quote! { vec![#value; #count] }
            }
            None => bail!("full() requires a value argument"),
        },
        _ => unreachable!(),
    };
    Ok(lowered)
}
pub fn convert_array_large_literal(
ctx: &mut CodeGenContext,
func: &str,
args: &[HirExpr],
) -> Result<syn::Expr> {
let size_expr = args[0].to_rust_expr(ctx)?;
match func {
"zeros" => Ok(parse_quote! { vec![0i32; #size_expr as usize] }),
"ones" => Ok(parse_quote! { vec![1i32; #size_expr as usize] }),
"full" => {
if args.len() >= 2 {
let value = args[1].to_rust_expr(ctx)?;
Ok(parse_quote! { vec![#value; #size_expr as usize] })
} else {
bail!("full() requires a value argument");
}
}
_ => unreachable!(),
}
}
/// Lowers `zeros(n)`, `ones(n)` or `full(n, value)` where the size must be converted at
/// runtime.
///
/// Produces `vec![0i32; (n) as usize]`, `vec![1i32; (n) as usize]` or
/// `vec![value; (n) as usize]`. The size expression is parenthesized before the cast, so a
/// compound size such as `n + 1` is cast as a whole instead of being parsed as
/// `n + (1 as usize)`.
///
/// # Errors
/// Fails if `args` is empty, if `full` lacks its value argument, or if converting an
/// argument to Rust fails.
///
/// # Panics
/// Panics (`unreachable!`) when `func` is not `zeros`, `ones` or `full`.
pub fn convert_array_dynamic_size(
    ctx: &mut CodeGenContext,
    func: &str,
    args: &[HirExpr],
) -> Result<syn::Expr> {
    // This is a pub entry point, so don't rely on the caller having checked for a size arg.
    if args.is_empty() {
        bail!("{} requires at least one argument", func);
    }
    let size_expr = args[0].to_rust_expr(ctx)?;
    match func {
        "zeros" => Ok(parse_quote! { vec![0i32; (#size_expr) as usize] }),
        "ones" => Ok(parse_quote! { vec![1i32; (#size_expr) as usize] }),
        "full" => {
            if args.len() >= 2 {
                let value = args[1].to_rust_expr(ctx)?;
                Ok(parse_quote! { vec![#value; (#size_expr) as usize] })
            } else {
                bail!("full() requires a value argument");
            }
        }
        _ => unreachable!(),
    }
}
// Unit tests for the range() lowering helpers. Every test converts the generated `syn::Expr`
// back into a token string and checks for substrings. None of them compiles or runs the
// generated Rust, so they pin the shape of the output rather than its runtime behavior.
#[cfg(test)]
mod tests {
use super::*;
// range(5): the output mentions the implicit 0 start and the end bound.
#[test]
fn test_range_single_arg() {
let args: Vec<syn::Expr> = vec![parse_quote! { 5 }];
let result = convert_range_call(&args).unwrap();
let result_str = quote::quote!(#result).to_string();
assert!(result_str.contains("0"));
assert!(result_str.contains("5"));
}
// range(2, 7): both bounds appear in the output.
#[test]
fn test_range_two_args() {
let args: Vec<syn::Expr> = vec![parse_quote! { 2 }, parse_quote! { 7 }];
let result = convert_range_call(&args).unwrap();
let result_str = quote::quote!(#result).to_string();
assert!(result_str.contains("2"));
assert!(result_str.contains("7"));
}
// range(0, 10, 2) through the public entry point produces stepping code.
#[test]
fn test_range_with_positive_step() {
let args: Vec<syn::Expr> =
vec![parse_quote! { 0 }, parse_quote! { 10 }, parse_quote! { 2 }];
let result = convert_range_call(&args).unwrap();
let result_str = quote::quote!(#result).to_string();
assert!(result_str.contains("step_by") || result_str.contains("step"));
}
// range(10, 0, -1): a syntactically negative step produces a reversed iterator.
#[test]
fn test_range_with_negative_step() {
let args: Vec<syn::Expr> =
vec![parse_quote! { 10 }, parse_quote! { 0 }, parse_quote! { -1 }];
let result = convert_range_call(&args).unwrap();
let result_str = quote::quote!(#result).to_string();
assert!(result_str.contains("rev"));
}
// Zero and four arguments are both rejected with an error, not a panic.
#[test]
fn test_range_invalid_args() {
let args: Vec<syn::Expr> = vec![];
assert!(convert_range_call(&args).is_err());
let too_many: Vec<syn::Expr> = vec![
parse_quote! { 0 },
parse_quote! { 10 },
parse_quote! { 2 },
parse_quote! { 3 },
];
assert!(convert_range_call(&too_many).is_err());
}
// Calling the positive-step helper directly emits step_by.
#[test]
fn test_range_positive_step_direct() {
let start: syn::Expr = parse_quote! { 0 };
let end: syn::Expr = parse_quote! { 10 };
let step: syn::Expr = parse_quote! { 2 };
let result = convert_range_positive_step(&start, &end, &step).unwrap();
let result_str = quote::quote!(#result).to_string();
assert!(result_str.contains("step_by"));
}
// Calling the negative-step helper directly emits rev and takes the step's magnitude.
#[test]
fn test_range_negative_step_direct() {
let start: syn::Expr = parse_quote! { 10 };
let end: syn::Expr = parse_quote! { 0 };
let step: syn::Expr = parse_quote! { -1 };
let result = convert_range_negative_step(&start, &end, &step).unwrap();
let result_str = quote::quote!(#result).to_string();
assert!(result_str.contains("rev"));
assert!(result_str.contains("abs"));
}
// The dispatcher routes `2` to the forward path (no rev) and `-2` to the reversed path.
#[test]
fn test_convert_range_with_step_dispatch() {
let start: syn::Expr = parse_quote! { 0 };
let end: syn::Expr = parse_quote! { 10 };
let step_pos: syn::Expr = parse_quote! { 2 };
let result_pos = convert_range_with_step(&start, &end, &step_pos).unwrap();
let pos_str = quote::quote!(#result_pos).to_string();
assert!(!pos_str.contains("rev"));
let step_neg: syn::Expr = parse_quote! { -2 };
let result_neg = convert_range_with_step(&start, &end, &step_neg).unwrap();
let neg_str = quote::quote!(#result_neg).to_string();
assert!(neg_str.contains("rev"));
}
// A compound end expression is accepted by the one-argument form.
#[test]
fn test_range_complex_end() {
let args: Vec<syn::Expr> = vec![parse_quote! { n + 1 }];
let result = convert_range_call(&args).unwrap();
let result_str = quote::quote!(#result).to_string();
assert!(result_str.contains("0"));
}
// Identifier bounds are carried through unchanged.
#[test]
fn test_range_variable_bounds() {
let args: Vec<syn::Expr> = vec![parse_quote! { start }, parse_quote! { end }];
let result = convert_range_call(&args).unwrap();
let result_str = quote::quote!(#result).to_string();
assert!(result_str.contains("start"));
assert!(result_str.contains("end"));
}
// A call expression as the bound is carried through unchanged.
#[test]
fn test_range_function_call_bound() {
let args: Vec<syn::Expr> = vec![parse_quote! { len(items) }];
let result = convert_range_call(&args).unwrap();
let result_str = quote::quote!(#result).to_string();
assert!(result_str.contains("len"));
}
// A binary step expression takes the non-negative (forward) path.
#[test]
fn test_range_binary_step() {
let args: Vec<syn::Expr> = vec![
parse_quote! { 0 },
parse_quote! { 100 },
parse_quote! { n * 2 },
];
let result = convert_range_call(&args).unwrap();
let result_str = quote::quote!(#result).to_string();
assert!(result_str.contains("step_by") || result_str.contains("step"));
}
// The generated code contains a runtime zero-step guard.
#[test]
fn test_range_step_zero_protection() {
let start: syn::Expr = parse_quote! { 0 };
let end: syn::Expr = parse_quote! { 10 };
let step: syn::Expr = parse_quote! { step };
let result = convert_range_positive_step(&start, &end, &step).unwrap();
let result_str = quote::quote!(#result).to_string();
assert!(result_str.contains("== 0") || result_str.contains("panic"));
}
// A negative step's magnitude is taken in the generated code.
#[test]
fn test_range_negative_step_abs() {
let start: syn::Expr = parse_quote! { 10 };
let end: syn::Expr = parse_quote! { 0 };
let step: syn::Expr = parse_quote! { -2 };
let result = convert_range_negative_step(&start, &end, &step).unwrap();
let result_str = quote::quote!(#result).to_string();
assert!(result_str.contains("abs"));
}
// range(0) is accepted and produces output mentioning 0.
#[test]
fn test_range_literal_zero() {
let args: Vec<syn::Expr> = vec![parse_quote! { 0 }];
let result = convert_range_call(&args).unwrap();
let result_str = quote::quote!(#result).to_string();
assert!(result_str.contains("0"));
}
// A negative start bound keeps its minus sign in the output.
#[test]
fn test_range_negative_start() {
let args: Vec<syn::Expr> = vec![parse_quote! { -5 }, parse_quote! { 5 }];
let result = convert_range_call(&args).unwrap();
let result_str = quote::quote!(#result).to_string();
assert!(result_str.contains("-"));
assert!(result_str.contains("5"));
}
}