#![allow(clippy::doc_markdown)]
use super::Transpiler;
use crate::frontend::ast::{Expr, ExprKind, Param, TypeKind};
use anyhow::Result;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use std::collections::HashSet;
impl Transpiler {
pub(crate) fn infer_return_type_from_params_impl(
&self,
body: &Expr,
params: &[Param],
) -> Result<Option<TokenStream>> {
super::return_type_helpers::infer_return_type_from_params(body, params, |ty| {
self.transpile_type(ty)
})
}
pub(crate) fn infer_param_type_impl(
&self,
param: &Param,
body: &Expr,
func_name: &str,
) -> TokenStream {
self.infer_param_type_with_index(param, body, func_name, None)
}
pub(crate) fn infer_param_type_with_index(
&self,
param: &Param,
body: &Expr,
func_name: &str,
param_index: Option<usize>,
) -> TokenStream {
use super::type_inference::{
infer_param_type_from_builtin_usage, is_param_used_as_array, is_param_used_as_bool,
is_param_used_as_function, is_param_used_as_index, is_param_used_in_print_macro,
is_param_used_in_string_concat, is_param_used_numerically, is_param_used_with_len,
};
if let Some(idx) = param_index {
if let Some(call_site_type) = self.get_call_site_param_type(func_name, idx) {
if call_site_type != "_" {
match call_site_type.as_str() {
"f64" => return quote! { f64 },
"f32" => return quote! { f32 },
"i64" => return quote! { i64 },
"i32" => return quote! { i32 },
"String" => return quote! { String },
"bool" => return quote! { bool },
t if t.starts_with("Vec<") => {
let inner = &t[4..t.len() - 1];
let inner_ident = format_ident!("{}", inner);
return quote! { Vec<#inner_ident> };
}
_ => {} }
}
}
}
if is_param_used_as_function(¶m.name(), body) {
return quote! { impl Fn(i32) -> i32 };
}
if is_param_used_as_bool(¶m.name(), body) {
return quote! { bool };
}
if is_param_used_as_array(¶m.name(), body) {
if self.is_nested_array_param_impl(¶m.name(), body) {
return quote! { &Vec<Vec<i32>> };
}
return quote! { &Vec<i32> };
}
if is_param_used_with_len(¶m.name(), body) {
if self.is_nested_array_param_impl(¶m.name(), body) {
return quote! { &Vec<Vec<i32>> };
}
return quote! { &Vec<i32> };
}
if is_param_used_as_index(¶m.name(), body) {
return quote! { i32 };
}
if is_param_used_numerically(¶m.name(), body)
|| super::function_analysis::looks_like_numeric_function(func_name)
{
return quote! { i32 };
}
if let Some(type_hint) = infer_param_type_from_builtin_usage(¶m.name(), body) {
if type_hint == "&str" {
return quote! { &str };
}
}
if is_param_used_in_string_concat(¶m.name(), body) {
return quote! { &str };
}
if is_param_used_in_print_macro(¶m.name(), body) {
return quote! { &str };
}
quote! { i32 }
}
pub(crate) fn is_nested_array_param_impl(&self, param_name: &str, expr: &Expr) -> bool {
Self::find_nested_array_access_impl(param_name, expr, &mut HashSet::new())
}
fn find_nested_array_access_impl(
param_name: &str,
expr: &Expr,
visited: &mut HashSet<usize>,
) -> bool {
let expr_addr = std::ptr::from_ref(expr) as usize;
if visited.contains(&expr_addr) {
return false;
}
visited.insert(expr_addr);
match &expr.kind {
ExprKind::IndexAccess { object, .. } => {
if let ExprKind::IndexAccess { object: inner, .. } = &object.kind {
if let ExprKind::Identifier(name) = &inner.kind {
if name == param_name {
return true;
}
}
}
Self::find_nested_array_access_impl(param_name, object, visited)
}
ExprKind::Block(exprs) => exprs
.iter()
.any(|e| Self::find_nested_array_access_impl(param_name, e, visited)),
ExprKind::Let { value, body, .. } | ExprKind::LetPattern { value, body, .. } => {
Self::find_nested_array_access_impl(param_name, value, visited)
|| Self::find_nested_array_access_impl(param_name, body, visited)
}
ExprKind::Binary { left, right, .. } => {
Self::find_nested_array_access_impl(param_name, left, visited)
|| Self::find_nested_array_access_impl(param_name, right, visited)
}
ExprKind::While {
condition, body, ..
} => {
Self::find_nested_array_access_impl(param_name, condition, visited)
|| Self::find_nested_array_access_impl(param_name, body, visited)
}
ExprKind::If {
condition,
then_branch,
else_branch,
} => {
Self::find_nested_array_access_impl(param_name, condition, visited)
|| Self::find_nested_array_access_impl(param_name, then_branch, visited)
|| else_branch.as_ref().is_some_and(|e| {
Self::find_nested_array_access_impl(param_name, e, visited)
})
}
ExprKind::Assign { target, value } | ExprKind::CompoundAssign { target, value, .. } => {
Self::find_nested_array_access_impl(param_name, target, visited)
|| Self::find_nested_array_access_impl(param_name, value, visited)
}
_ => false,
}
}
pub(crate) fn generate_param_tokens_impl(
&self,
params: &[Param],
body: &Expr,
func_name: &str,
) -> Result<Vec<TokenStream>> {
params
.iter()
.enumerate()
.map(|(idx, p)| {
let param_name = format_ident!("{}", p.name());
if p.name() == "self" {
if let TypeKind::Reference { is_mut, .. } = &p.ty.kind {
if *is_mut {
return Ok(quote! { &mut self });
}
return Ok(quote! { &self });
}
return Ok(quote! { self });
}
let type_tokens = if let Ok(tokens) = self.transpile_type(&p.ty) {
let token_str = tokens.to_string();
if token_str == "_" {
self.infer_param_type_with_index(p, body, func_name, Some(idx))
} else {
tokens
}
} else {
self.infer_param_type_with_index(p, body, func_name, Some(idx))
};
if p.is_mutable {
Ok(quote! { mut #param_name: #type_tokens })
} else {
Ok(quote! { #param_name: #type_tokens })
}
})
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frontend::ast::{Literal, Pattern, Span, Type};

    fn make_transpiler() -> Transpiler {
        Transpiler::new()
    }

    fn make_expr(kind: ExprKind) -> Expr {
        Expr {
            kind,
            span: Span::default(),
            attributes: vec![],
            leading_comments: vec![],
            trailing_comment: None,
        }
    }

    fn int_expr(n: i64) -> Expr {
        make_expr(ExprKind::Literal(Literal::Integer(n, None)))
    }

    fn ident_expr(name: &str) -> Expr {
        make_expr(ExprKind::Identifier(name.to_string()))
    }

    /// Builds the expression `object[index]`.
    fn index_expr(object: Expr, index: Expr) -> Expr {
        make_expr(ExprKind::IndexAccess {
            object: Box::new(object),
            index: Box::new(index),
        })
    }

    /// Builds an immutable parameter with the placeholder type `_`, which forces inference.
    fn make_param(name: &str) -> Param {
        Param {
            pattern: Pattern::Identifier(name.to_string()),
            ty: Type {
                kind: TypeKind::Named("_".to_string()),
                span: Span::default(),
            },
            span: Span::default(),
            is_mutable: false,
            default_value: None,
        }
    }

    /// Like [`make_param`], but marked `mut`.
    fn make_mut_param(name: &str) -> Param {
        Param {
            is_mutable: true,
            ..make_param(name)
        }
    }

    /// Builds a `self` parameter typed as `&Self` or `&mut Self`.
    fn self_ref_param(is_mut: bool) -> Param {
        let mut param = make_param("self");
        param.ty = Type {
            kind: TypeKind::Reference {
                is_mut,
                lifetime: None,
                inner: Box::new(Type {
                    kind: TypeKind::Named("Self".to_string()),
                    span: Span::default(),
                }),
            },
            span: Span::default(),
        };
        param
    }

    #[test]
    fn test_infer_param_type_numeric_function() {
        let transpiler = make_transpiler();
        let param = make_param("x");
        let body = int_expr(42);
        let result = transpiler.infer_param_type_impl(&param, &body, "add");
        assert_eq!(result.to_string(), "i32");
    }

    #[test]
    fn test_infer_param_type_bool_condition() {
        let transpiler = make_transpiler();
        let param = make_param("flag");
        let body = make_expr(ExprKind::If {
            condition: Box::new(ident_expr("flag")),
            then_branch: Box::new(int_expr(1)),
            else_branch: Some(Box::new(int_expr(0))),
        });
        let result = transpiler.infer_param_type_impl(&param, &body, "check");
        assert_eq!(result.to_string(), "bool");
    }

    #[test]
    fn test_infer_param_type_array_indexing() {
        let transpiler = make_transpiler();
        let param = make_param("arr");
        let body = index_expr(ident_expr("arr"), int_expr(0));
        let result = transpiler.infer_param_type_impl(&param, &body, "get_first");
        assert_eq!(result.to_string(), "& Vec < i32 >");
    }

    #[test]
    fn test_infer_param_type_nested_array() {
        let transpiler = make_transpiler();
        let param = make_param("matrix");
        let body = index_expr(index_expr(ident_expr("matrix"), int_expr(0)), int_expr(1));
        let result = transpiler.infer_param_type_impl(&param, &body, "get_element");
        let result_str = result.to_string();
        assert!(result_str.contains("Vec") && result_str.contains("i32"));
        assert!(result_str.starts_with('&'));
    }

    #[test]
    fn test_infer_param_type_index_usage() {
        let transpiler = make_transpiler();
        let param = make_param("i");
        let body = index_expr(ident_expr("arr"), ident_expr("i"));
        let result = transpiler.infer_param_type_impl(&param, &body, "access");
        assert_eq!(result.to_string(), "i32");
    }

    #[test]
    fn test_infer_param_type_default() {
        let transpiler = make_transpiler();
        let param = make_param("unused");
        let body = int_expr(42);
        let result = transpiler.infer_param_type_impl(&param, &body, "foo");
        assert_eq!(result.to_string(), "i32");
    }

    #[test]
    fn test_is_nested_array_simple() {
        let transpiler = make_transpiler();
        let body = index_expr(ident_expr("arr"), int_expr(0));
        assert!(!transpiler.is_nested_array_param_impl("arr", &body));
    }

    #[test]
    fn test_is_nested_array_2d() {
        let transpiler = make_transpiler();
        let body = index_expr(index_expr(ident_expr("matrix"), int_expr(0)), int_expr(1));
        assert!(transpiler.is_nested_array_param_impl("matrix", &body));
    }

    #[test]
    fn test_is_nested_array_3d() {
        let transpiler = make_transpiler();
        let body = index_expr(
            index_expr(index_expr(ident_expr("cube"), int_expr(0)), int_expr(1)),
            int_expr(2),
        );
        assert!(transpiler.is_nested_array_param_impl("cube", &body));
    }

    #[test]
    fn test_is_nested_array_in_block() {
        let transpiler = make_transpiler();
        let body = make_expr(ExprKind::Block(vec![index_expr(
            index_expr(ident_expr("m"), ident_expr("i")),
            ident_expr("j"),
        )]));
        assert!(transpiler.is_nested_array_param_impl("m", &body));
    }

    #[test]
    fn test_is_nested_array_in_if() {
        let transpiler = make_transpiler();
        let nested_access = index_expr(index_expr(ident_expr("grid"), int_expr(0)), int_expr(0));
        let body = make_expr(ExprKind::If {
            condition: Box::new(make_expr(ExprKind::Literal(Literal::Bool(true)))),
            then_branch: Box::new(nested_access),
            else_branch: None,
        });
        assert!(transpiler.is_nested_array_param_impl("grid", &body));
    }

    #[test]
    fn test_is_nested_array_different_param() {
        let transpiler = make_transpiler();
        let body = index_expr(index_expr(ident_expr("other"), int_expr(0)), int_expr(1));
        assert!(!transpiler.is_nested_array_param_impl("matrix", &body));
    }

    #[test]
    fn test_is_nested_array_other_array_in_index() {
        // `arr[other[0]]` indexes `arr` only once, so it is not a nested access of `arr`.
        let transpiler = make_transpiler();
        let body = index_expr(ident_expr("arr"), index_expr(ident_expr("other"), int_expr(0)));
        assert!(!transpiler.is_nested_array_param_impl("arr", &body));
    }

    #[test]
    fn test_generate_param_tokens_simple() {
        let transpiler = make_transpiler();
        let params = vec![make_param("x")];
        let body = int_expr(42);
        let result = transpiler
            .generate_param_tokens_impl(&params, &body, "add")
            .unwrap();
        assert_eq!(result.len(), 1);
        assert!(result[0].to_string().contains('x'));
    }

    #[test]
    fn test_generate_param_tokens_mutable() {
        let transpiler = make_transpiler();
        let params = vec![make_mut_param("count")];
        let body = int_expr(0);
        let result = transpiler
            .generate_param_tokens_impl(&params, &body, "increment")
            .unwrap();
        assert!(result[0].to_string().contains("mut"));
        assert!(result[0].to_string().contains("count"));
    }

    #[test]
    fn test_generate_param_tokens_multiple() {
        let transpiler = make_transpiler();
        let params = vec![make_param("a"), make_param("b")];
        let body = int_expr(0);
        let result = transpiler
            .generate_param_tokens_impl(&params, &body, "add")
            .unwrap();
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_generate_param_tokens_self_ref() {
        let transpiler = make_transpiler();
        let params = vec![self_ref_param(false)];
        let body = int_expr(0);
        let result = transpiler
            .generate_param_tokens_impl(&params, &body, "method")
            .unwrap();
        assert_eq!(result[0].to_string(), "& self");
    }

    #[test]
    fn test_generate_param_tokens_self_mut() {
        let transpiler = make_transpiler();
        let params = vec![self_ref_param(true)];
        let body = int_expr(0);
        let result = transpiler
            .generate_param_tokens_impl(&params, &body, "method")
            .unwrap();
        assert_eq!(result[0].to_string(), "& mut self");
    }

    #[test]
    fn test_generate_param_tokens_self_owned() {
        let transpiler = make_transpiler();
        let params = vec![make_param("self")];
        let body = int_expr(0);
        let result = transpiler
            .generate_param_tokens_impl(&params, &body, "consume")
            .unwrap();
        assert_eq!(result[0].to_string(), "self");
    }

    #[test]
    fn test_infer_return_type_no_params() {
        let transpiler = make_transpiler();
        let body = int_expr(42);
        let result = transpiler
            .infer_return_type_from_params_impl(&body, &[])
            .unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_infer_return_type_with_params() {
        let transpiler = make_transpiler();
        let params = vec![make_param("x")];
        let body = ident_expr("x");
        // `return_type_helpers` decides the concrete answer. This smoke test checks
        // that inference over an untyped, directly returned parameter does not error.
        let result = transpiler.infer_return_type_from_params_impl(&body, &params);
        assert!(result.is_ok(), "inference failed: {result:?}");
    }
}