use crate::hir::{AssignTarget, HirExpr, HirFunction, HirStmt, Type};
use std::collections::HashSet;
/// Parameter classification state built up while analysing one function body.
///
/// Each set holds names; the sets are consulted to pick how each parameter is
/// passed in the generated Rust signature.
#[derive(Debug)]
#[derive(Default)]
pub struct BorrowingContext {
    /// Parameters that are the target of a plain `name = ...` assignment.
    mutated_params: HashSet<String>,
    /// Names that appear in returned or assigned values.
    escaping_params: HashSet<String>,
    /// Parameters that were neither mutated nor seen escaping.
    read_only_params: HashSet<String>,
    /// Parameters referenced inside a loop body, as found by the loop scan.
    loop_used_params: HashSet<String>,
}
/// How a single parameter is passed in a generated Rust function signature.
#[derive(Debug, Clone)]
#[derive(PartialEq)]
pub enum BorrowingPattern {
    /// Passed by value: `name: T`.
    Owned,
    /// Passed by shared reference: `name: &T`.
    Borrowed,
    /// Passed by mutable reference: `name: &mut T`.
    MutableBorrow,
}
impl BorrowingContext {
    /// Creates a context with no parameters classified yet.
    pub fn new() -> Self {
        Self::default()
    }

    /// Classifies each parameter of `func` by walking its body.
    ///
    /// Every parameter starts out read-only. Statements in the body move
    /// parameters into the mutated and/or escaping sets. Anything that ended up
    /// in either set is finally removed from the read-only set.
    ///
    /// State from any earlier call is discarded first. Reusing one context for
    /// several functions therefore cannot carry a classification from a
    /// previous function onto a same-named parameter of the next one.
    pub fn analyze_function(&mut self, func: &HirFunction) {
        *self = Self::default();

        for (param_name, _) in &func.params {
            self.read_only_params.insert(param_name.clone());
        }
        for stmt in &func.body {
            self.analyze_stmt(stmt);
        }
        for param in self.mutated_params.iter().chain(&self.escaping_params) {
            self.read_only_params.remove(param);
        }
    }

    /// Returns how `param_name` (declared as `param_type`) should be passed.
    ///
    /// Precedence:
    /// 1. escaping → `Owned`;
    /// 2. otherwise mutated → `MutableBorrow`;
    /// 3. otherwise a type mapped to a `Copy` Rust type → `Owned`;
    /// 4. otherwise `Borrowed`.
    ///
    /// Names the analysis never recorded fall through to the type-based rules.
    pub fn get_pattern(&self, param_name: &str, param_type: &Type) -> BorrowingPattern {
        if self.escaping_params.contains(param_name) {
            BorrowingPattern::Owned
        } else if self.mutated_params.contains(param_name) {
            BorrowingPattern::MutableBorrow
        } else if self.is_copyable(param_type) {
            BorrowingPattern::Owned
        } else {
            BorrowingPattern::Borrowed
        }
    }

    /// Renders `name: T`, `name: &T` or `name: &mut T` according to
    /// [`get_pattern`](Self::get_pattern).
    pub fn generate_param_signature(&self, param_name: &str, param_type: &Type) -> String {
        let pattern = self.get_pattern(param_name, param_type);
        let type_str = self.type_to_rust_string(param_type);
        match pattern {
            BorrowingPattern::Owned => format!("{}: {}", param_name, type_str),
            BorrowingPattern::Borrowed => format!("{}: &{}", param_name, type_str),
            BorrowingPattern::MutableBorrow => format!("{}: &mut {}", param_name, type_str),
        }
    }

    /// Reports whether `param_name` was referenced anywhere inside a
    /// `while`/`for` body during the last
    /// [`analyze_function`](Self::analyze_function) call.
    pub fn is_used_in_loop(&self, param_name: &str) -> bool {
        self.loop_used_params.contains(param_name)
    }

    /// Records mutation, escape and loop usage for one statement, recursing
    /// into nested bodies.
    ///
    /// Only plain `name = ...` assignments to a parameter count as mutation.
    /// Values of assignments and of `return expr` are checked for escaping
    /// names. Statement kinds not matched here, including a bare `return`, are
    /// ignored.
    fn analyze_stmt(&mut self, stmt: &HirStmt) {
        match stmt {
            HirStmt::Assign { target, value } => {
                if let AssignTarget::Symbol(symbol) = target {
                    // During the walk `read_only_params` still holds every
                    // parameter, so this is a "is it a parameter" test.
                    if self.read_only_params.contains(symbol) {
                        self.mutated_params.insert(symbol.clone());
                    }
                }
                self.check_escaping_expr(value);
                self.analyze_expr(value);
            }
            HirStmt::Return(Some(expr)) => {
                self.check_escaping_expr(expr);
                self.analyze_expr(expr);
            }
            HirStmt::Expr(expr) => {
                self.analyze_expr(expr);
            }
            HirStmt::If {
                condition,
                then_body,
                else_body,
            } => {
                self.analyze_expr(condition);
                for stmt in then_body {
                    self.analyze_stmt(stmt);
                }
                if let Some(else_stmts) = else_body {
                    for stmt in else_stmts {
                        self.analyze_stmt(stmt);
                    }
                }
            }
            HirStmt::While { condition, body } => {
                self.analyze_expr(condition);
                self.mark_loop_params(body);
                for stmt in body {
                    self.analyze_stmt(stmt);
                }
            }
            HirStmt::For {
                target: _,
                iter,
                body,
            } => {
                self.analyze_expr(iter);
                self.mark_loop_params(body);
                for stmt in body {
                    self.analyze_stmt(stmt);
                }
            }
            _ => {}
        }
    }

    /// Walks an expression tree.
    ///
    /// No expression form records anything yet. In particular, a call such as
    /// `append(x, ...)` is not treated as mutating `x`. The traversal is
    /// presumably kept as the place to add expression-level detection.
    // `self` is only passed through the recursion because nothing is recorded
    // yet, which is what the clippy lint flags.
    #[allow(clippy::only_used_in_recursion)]
    fn analyze_expr(&mut self, expr: &HirExpr) {
        match expr {
            HirExpr::Binary { op: _, left, right } => {
                self.analyze_expr(left);
                self.analyze_expr(right);
            }
            HirExpr::Unary { op: _, operand } => {
                self.analyze_expr(operand);
            }
            HirExpr::Call { func: _, args } => {
                for arg in args {
                    self.analyze_expr(arg);
                }
            }
            HirExpr::List(elts) => {
                for elt in elts {
                    self.analyze_expr(elt);
                }
            }
            HirExpr::Dict(items) => {
                for (k, v) in items {
                    self.analyze_expr(k);
                    self.analyze_expr(v);
                }
            }
            HirExpr::Tuple(elts) => {
                for elt in elts {
                    self.analyze_expr(elt);
                }
            }
            HirExpr::Index { base, index } => {
                self.analyze_expr(base);
                self.analyze_expr(index);
            }
            _ => {}
        }
    }

    /// Marks every variable that is moved into a returned or assigned value as
    /// escaping.
    ///
    /// A bare variable escapes, as does any variable stored in a list, tuple or
    /// dict (keys and values), at any nesting depth. Variables that only appear
    /// inside other expression forms (calls, arithmetic, indexing) are not
    /// treated as escaping.
    fn check_escaping_expr(&mut self, expr: &HirExpr) {
        match expr {
            HirExpr::Var(name) => {
                self.escaping_params.insert(name.clone());
            }
            HirExpr::List(elts) | HirExpr::Tuple(elts) => {
                for elt in elts {
                    self.check_escaping_expr(elt);
                }
            }
            HirExpr::Dict(items) => {
                for (key, value) in items {
                    self.check_escaping_expr(key);
                    self.check_escaping_expr(value);
                }
            }
            _ => {}
        }
    }

    /// Records every parameter referenced anywhere in `body`, including nested
    /// blocks, as loop-used.
    fn mark_loop_params(&mut self, body: &[HirStmt]) {
        for stmt in body {
            self.find_params_in_stmt(stmt);
        }
    }

    /// Scans one statement, and any statements nested in it, for parameter
    /// references.
    fn find_params_in_stmt(&mut self, stmt: &HirStmt) {
        match stmt {
            HirStmt::Expr(expr) | HirStmt::Return(Some(expr)) => self.find_params_in_expr(expr),
            HirStmt::Assign { value, .. } => self.find_params_in_expr(value),
            HirStmt::If {
                condition,
                then_body,
                else_body,
            } => {
                self.find_params_in_expr(condition);
                self.mark_loop_params(then_body);
                if let Some(else_stmts) = else_body {
                    self.mark_loop_params(else_stmts);
                }
            }
            HirStmt::While { condition, body } => {
                self.find_params_in_expr(condition);
                self.mark_loop_params(body);
            }
            HirStmt::For { iter, body, .. } => {
                self.find_params_in_expr(iter);
                self.mark_loop_params(body);
            }
            _ => {}
        }
    }

    /// Scans an expression tree for parameter references.
    fn find_params_in_expr(&mut self, expr: &HirExpr) {
        match expr {
            HirExpr::Var(name) => {
                // `read_only_params` is only pruned after the body walk, so
                // while scanning it contains exactly the function's
                // parameters. Local variables are not recorded.
                if self.read_only_params.contains(name) {
                    self.loop_used_params.insert(name.clone());
                }
            }
            HirExpr::Binary { left, right, .. } => {
                self.find_params_in_expr(left);
                self.find_params_in_expr(right);
            }
            HirExpr::Unary { operand, .. } => {
                self.find_params_in_expr(operand);
            }
            HirExpr::Call { args, .. } => {
                for arg in args {
                    self.find_params_in_expr(arg);
                }
            }
            HirExpr::List(elts) | HirExpr::Tuple(elts) => {
                for elt in elts {
                    self.find_params_in_expr(elt);
                }
            }
            HirExpr::Dict(items) => {
                for (key, value) in items {
                    self.find_params_in_expr(key);
                    self.find_params_in_expr(value);
                }
            }
            HirExpr::Index { base, index } => {
                self.find_params_in_expr(base);
                self.find_params_in_expr(index);
            }
            _ => {}
        }
    }

    /// True for HIR types that `type_to_rust_string` renders as `Copy` Rust
    /// types (`i32`, `f64`, `bool`, `()`). These are cheaper to pass by value.
    fn is_copyable(&self, ty: &Type) -> bool {
        matches!(ty, Type::Int | Type::Float | Type::Bool | Type::None)
    }

    /// Renders the Rust spelling of `ty` used in generated signatures.
    ///
    /// Some variants map to placeholders rather than real Rust types:
    /// `Union` becomes the literal `Union`, functions become a
    /// `/* function */` comment, and `Generic` drops its type arguments.
    // `self` is only passed through the recursion, which is what the clippy
    // lint flags.
    #[allow(clippy::only_used_in_recursion)]
    fn type_to_rust_string(&self, ty: &Type) -> String {
        match ty {
            Type::Unknown => "serde_json::Value".to_string(),
            Type::Int => "i32".to_string(),
            Type::Float => "f64".to_string(),
            Type::String => "String".to_string(),
            Type::Bool => "bool".to_string(),
            Type::None => "()".to_string(),
            Type::List(inner) => format!("Vec<{}>", self.type_to_rust_string(inner)),
            Type::Dict(k, v) => format!(
                "HashMap<{}, {}>",
                self.type_to_rust_string(k),
                self.type_to_rust_string(v)
            ),
            Type::Tuple(types) => {
                if types.is_empty() {
                    "()".to_string()
                } else {
                    let type_strs: Vec<String> =
                        types.iter().map(|t| self.type_to_rust_string(t)).collect();
                    format!("({})", type_strs.join(", "))
                }
            }
            Type::Optional(inner) => format!("Option<{}>", self.type_to_rust_string(inner)),
            Type::Function { .. } => "/* function */".to_string(),
            Type::Custom(name) => name.clone(),
            Type::TypeVar(name) => name.clone(),
            Type::Generic { base, .. } => base.clone(),
            Type::Union(_) => "Union".to_string(),
            Type::Array { element_type, .. } => {
                format!("Array<{}>", self.type_to_rust_string(element_type))
            }
            Type::Set(element) => format!("HashSet<{}>", self.type_to_rust_string(element)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::{BinOp, FunctionProperties, Literal};
    use depyler_annotations::TranspilationAnnotations;
    use smallvec::smallvec;

    /// A parameter only passed to a call stays a shared borrow.
    #[test]
    fn test_read_only_parameter() {
        let mut ctx = BorrowingContext::new();
        let func = HirFunction {
            name: "test".to_string(),
            params: smallvec![("x".to_string(), Type::String)],
            ret_type: Type::Int,
            body: vec![HirStmt::Return(Some(HirExpr::Call {
                func: "len".to_string(),
                args: vec![HirExpr::Var("x".to_string())],
            }))],
            properties: FunctionProperties::default(),
            annotations: TranspilationAnnotations::default(),
            docstring: None,
        };
        ctx.analyze_function(&func);
        assert_eq!(
            ctx.get_pattern("x", &Type::String),
            BorrowingPattern::Borrowed
        );
    }

    /// Rebinding a parameter (`x = [42]`) is recorded as a mutation. The
    /// right-hand side has no variables, so nothing escapes. Call-style
    /// mutation such as `append(x, 42)` is not detected by the analysis, so it
    /// cannot be used to exercise this path.
    #[test]
    fn test_mutated_parameter() {
        let mut ctx = BorrowingContext::new();
        let list_of_int = Type::List(Box::new(Type::Int));
        let func = HirFunction {
            name: "test".to_string(),
            params: smallvec![("x".to_string(), list_of_int.clone())],
            ret_type: Type::None,
            body: vec![HirStmt::Assign {
                target: AssignTarget::Symbol("x".to_string()),
                value: HirExpr::List(vec![HirExpr::Literal(Literal::Int(42))]),
            }],
            properties: FunctionProperties::default(),
            annotations: TranspilationAnnotations::default(),
            docstring: None,
        };
        ctx.analyze_function(&func);
        assert_eq!(
            ctx.get_pattern("x", &list_of_int),
            BorrowingPattern::MutableBorrow
        );
    }

    /// Returning a parameter directly makes it owned.
    #[test]
    fn test_escaping_parameter() {
        let mut ctx = BorrowingContext::new();
        let func = HirFunction {
            name: "test".to_string(),
            params: smallvec![("x".to_string(), Type::String)],
            ret_type: Type::String,
            body: vec![HirStmt::Return(Some(HirExpr::Var("x".to_string())))],
            properties: FunctionProperties::default(),
            annotations: TranspilationAnnotations::default(),
            docstring: None,
        };
        ctx.analyze_function(&func);
        assert_eq!(ctx.get_pattern("x", &Type::String), BorrowingPattern::Owned);
    }

    /// Copy-mapped types are passed by value even when only read.
    #[test]
    fn test_copyable_parameter() {
        let mut ctx = BorrowingContext::new();
        let func = HirFunction {
            name: "test".to_string(),
            params: smallvec![("x".to_string(), Type::Int)],
            ret_type: Type::Int,
            body: vec![HirStmt::Return(Some(HirExpr::Binary {
                op: BinOp::Add,
                left: Box::new(HirExpr::Var("x".to_string())),
                right: Box::new(HirExpr::Literal(Literal::Int(1))),
            }))],
            properties: FunctionProperties::default(),
            annotations: TranspilationAnnotations::default(),
            docstring: None,
        };
        ctx.analyze_function(&func);
        assert_eq!(ctx.get_pattern("x", &Type::Int), BorrowingPattern::Owned);
    }

    /// Each borrowing pattern renders the expected signature fragment.
    #[test]
    fn test_generate_param_signature() {
        let ctx = BorrowingContext::new();
        let mut ctx_borrow = BorrowingContext::new();
        ctx_borrow.read_only_params.insert("s".to_string());
        assert_eq!(
            ctx_borrow.generate_param_signature("s", &Type::String),
            "s: &String"
        );
        assert_eq!(ctx.generate_param_signature("n", &Type::Int), "n: i32");

        let mut ctx_mut = BorrowingContext::new();
        ctx_mut.mutated_params.insert("v".to_string());
        assert_eq!(
            ctx_mut.generate_param_signature("v", &Type::List(Box::new(Type::Int))),
            "v: &mut Vec<i32>"
        );
    }
}