use crate::hir::{HirExpr, HirFunction, HirStmt};
use crate::type_mapper::RustType;
use indexmap::IndexMap;
use std::collections::{HashMap, HashSet};
/// Infers which function parameters can be borrowed and which lifetime
/// parameters the generated Rust signature needs.
///
/// All fields are per-function analysis state filled in while a function is
/// analysed.
#[derive(Debug)]
pub struct LifetimeInference {
// Number of lifetime names handed out so far; drives the `'a`, `'b`, ... sequence.
lifetime_counter: usize,
// Parameter name -> lifetime assigned to it (only parameters chosen to be borrowed).
variable_lifetimes: HashMap<String, LifetimeInfo>,
// Lifetime name -> set of lifetimes it has been related to via a constraint.
lifetime_constraints: HashMap<String, HashSet<String>>,
// Parameter name -> summary of how the function body uses that parameter.
param_analysis: HashMap<String, ParamUsage>,
}
/// Lifetime bookkeeping for one named variable. In this module only borrowed
/// function parameters get an entry.
#[derive(Debug, Clone)]
pub struct LifetimeInfo {
/// Lifetime name including the leading apostrophe, e.g. `'a`.
pub name: String,
/// Whether this lifetime is `'static`.
pub is_static: bool,
/// Lifetimes this one must outlive.
/// NOTE(review): always created empty in this module; constraints are tracked
/// separately on the inference engine — confirm whether anything fills this.
pub outlives: HashSet<String>,
/// Where the lifetime comes from.
pub source: LifetimeSource,
}
/// Origin of an inferred lifetime.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LifetimeSource {
/// A borrowed function parameter; holds the parameter name.
Parameter(String),
/// Presumably a `'static` literal — not constructed in this module.
StaticLiteral,
/// Presumably a local variable — not constructed in this module.
Local,
/// Presumably the function's return value — not constructed in this module.
Return,
/// Presumably a field access; the `String` is likely the field name — not
/// constructed in this module, TODO confirm.
Field(String),
}
/// How a function body uses one parameter, as collected by the lifetime
/// inference pass. All flags start `false` and are only ever switched on.
///
/// Equality is derived so usage summaries can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ParamUsage {
    /// The parameter is the target of an assignment statement.
    pub is_mutated: bool,
    /// The parameter is passed directly (as a bare variable) to a call.
    pub is_moved: bool,
    /// The parameter appears in a returned expression, either directly or
    /// through a position that propagates the return context.
    pub escapes: bool,
    /// The parameter is read somewhere in the body.
    ///
    /// Despite the name, this is set on any read and is not cleared when the
    /// parameter is also mutated.
    pub is_read_only: bool,
    /// The parameter is read inside a `while`/`for` loop (including the loop
    /// condition or iterator expression).
    pub used_in_loop: bool,
    /// An attribute is accessed directly on the parameter (`param.attr`).
    pub has_nested_borrows: bool,
}
/// Kind of relationship between two lifetimes.
///
/// NOTE(review): the inference engine currently records only the
/// `from -> to` edge and ignores the kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LifetimeConstraint {
    /// The first lifetime must outlive the second (`'x: 'y`).
    Outlives,
    /// Presumably both lifetimes must be identical — not used in this module.
    Equal,
    /// Presumably a lower bound on a lifetime — not used in this module.
    AtLeast,
}
/// Outcome of lifetime inference for one function.
#[derive(Debug, Clone)]
pub struct LifetimeResult {
/// Per-parameter result keyed by parameter name, in the function's
/// parameter declaration order.
pub param_lifetimes: IndexMap<String, InferredParam>,
/// Lifetime for the return type, or `None` when none is needed (or it was elided).
pub return_lifetime: Option<String>,
/// Generic lifetime parameters to declare on the function.
pub lifetime_params: Vec<String>,
/// `(from, to)` pairs taken from recorded constraints, presumably meaning
/// `from: to` — TODO confirm with the code generator.
pub lifetime_bounds: Vec<(String, String)>,
}
/// Inferred passing convention for a single parameter.
#[derive(Debug, Clone)]
pub struct InferredParam {
/// Pass by reference instead of by value.
pub should_borrow: bool,
/// The borrow must be mutable (the parameter is assigned in the body).
pub needs_mut: bool,
/// Lifetime of the borrow; only assigned when `should_borrow` is true.
pub lifetime: Option<String>,
/// The parameter's Rust type as produced by the type mapper.
pub rust_type: RustType,
}
impl LifetimeInference {
pub fn new() -> Self {
Self {
lifetime_counter: 0,
variable_lifetimes: HashMap::new(),
lifetime_constraints: HashMap::new(),
param_analysis: HashMap::new(),
}
}
fn next_lifetime(&mut self) -> String {
let name = match self.lifetime_counter {
0 => "'a".to_string(),
1 => "'b".to_string(),
2 => "'c".to_string(),
n => format!("'l{}", n - 2),
};
self.lifetime_counter += 1;
name
}
fn add_constraint(&mut self, from: &str, to: &str, _constraint: LifetimeConstraint) {
self.lifetime_constraints
.entry(from.to_string())
.or_default()
.insert(to.to_string());
}
pub fn analyze_function(&mut self, func: &HirFunction, type_mapper: &crate::type_mapper::TypeMapper) -> LifetimeResult {
self.analyze_parameter_usage(func);
let param_lifetimes = self.infer_parameter_lifetimes(func, type_mapper);
let return_lifetime = self.analyze_return_lifetime(func, type_mapper);
let lifetime_bounds = self.compute_lifetime_bounds();
let mut lifetime_params = HashSet::new();
for param in param_lifetimes.values() {
if let Some(ref lt) = param.lifetime {
lifetime_params.insert(lt.clone());
}
}
if let Some(ref lt) = return_lifetime {
lifetime_params.insert(lt.clone());
}
LifetimeResult {
param_lifetimes,
return_lifetime,
lifetime_params: lifetime_params.into_iter().collect(),
lifetime_bounds,
}
}
fn analyze_parameter_usage(&mut self, func: &HirFunction) {
for (param_name, _param_type) in &func.params {
let mut usage = ParamUsage::default();
for stmt in &func.body {
self.analyze_stmt_for_param(param_name, stmt, &mut usage, false);
}
self.param_analysis.insert(param_name.clone(), usage);
}
}
fn analyze_stmt_for_param(&self, param: &str, stmt: &HirStmt, usage: &mut ParamUsage, in_loop: bool) {
match stmt {
HirStmt::Expr(expr) => self.analyze_expr_for_param(param, expr, usage, in_loop, false),
HirStmt::Assign { target, value } => {
if target == param {
usage.is_mutated = true;
}
self.analyze_expr_for_param(param, value, usage, in_loop, false);
}
HirStmt::Return(value) => {
if let Some(expr) = value {
self.analyze_expr_for_param(param, expr, usage, in_loop, true);
}
}
HirStmt::If { condition, then_body, else_body } => {
self.analyze_expr_for_param(param, condition, usage, in_loop, false);
for stmt in then_body {
self.analyze_stmt_for_param(param, stmt, usage, in_loop);
}
if let Some(else_stmts) = else_body {
for stmt in else_stmts {
self.analyze_stmt_for_param(param, stmt, usage, in_loop);
}
}
}
HirStmt::While { condition, body } => {
self.analyze_expr_for_param(param, condition, usage, true, false);
for stmt in body {
self.analyze_stmt_for_param(param, stmt, usage, true);
}
}
HirStmt::For { iter, body, .. } => {
self.analyze_expr_for_param(param, iter, usage, true, false);
for stmt in body {
self.analyze_stmt_for_param(param, stmt, usage, true);
}
}
}
}
#[allow(clippy::only_used_in_recursion)]
fn analyze_expr_for_param(&self, param: &str, expr: &HirExpr, usage: &mut ParamUsage, in_loop: bool, in_return: bool) {
match expr {
HirExpr::Var(id) => {
if id == param {
usage.is_read_only = true;
if in_return {
usage.escapes = true;
}
if in_loop {
usage.used_in_loop = true;
}
}
}
HirExpr::Attribute { value, .. } => {
if let HirExpr::Var(id) = &**value {
if id == param {
usage.is_read_only = true;
usage.has_nested_borrows = true;
}
}
self.analyze_expr_for_param(param, value, usage, in_loop, in_return);
}
HirExpr::Index { base, index } => {
self.analyze_expr_for_param(param, base, usage, in_loop, false);
self.analyze_expr_for_param(param, index, usage, in_loop, false);
}
HirExpr::Call { func: _, args } => {
for arg in args {
if let HirExpr::Var(id) = arg {
if id == param {
usage.is_moved = true;
}
}
}
for arg in args {
self.analyze_expr_for_param(param, arg, usage, in_loop, false);
}
}
HirExpr::List(elements) | HirExpr::Tuple(elements) => {
for elem in elements {
self.analyze_expr_for_param(param, elem, usage, in_loop, in_return);
}
}
HirExpr::Dict(pairs) => {
for (k, v) in pairs {
self.analyze_expr_for_param(param, k, usage, in_loop, false);
self.analyze_expr_for_param(param, v, usage, in_loop, in_return);
}
}
HirExpr::Binary { left, right, .. } => {
self.analyze_expr_for_param(param, left, usage, in_loop, false);
self.analyze_expr_for_param(param, right, usage, in_loop, false);
}
HirExpr::Unary { operand, .. } => {
self.analyze_expr_for_param(param, operand, usage, in_loop, false);
}
HirExpr::Literal(_) => {}
HirExpr::Borrow { expr, .. } => {
self.analyze_expr_for_param(param, expr, usage, in_loop, in_return);
}
}
}
fn infer_parameter_lifetimes(&mut self, func: &HirFunction, type_mapper: &crate::type_mapper::TypeMapper) -> IndexMap<String, InferredParam> {
let mut result = IndexMap::new();
for (param_name, param_type) in &func.params {
let usage = self.param_analysis.get(param_name).cloned().unwrap_or_default();
let rust_type = type_mapper.map_type(param_type);
let escapes_as_self = usage.escapes && rust_type == type_mapper.map_return_type(&func.ret_type);
let should_borrow = !usage.is_moved && !escapes_as_self && (usage.is_read_only || usage.is_mutated);
let needs_mut = usage.is_mutated;
let lifetime = if should_borrow {
let lt = self.next_lifetime();
self.variable_lifetimes.insert(
param_name.clone(),
LifetimeInfo {
name: lt.clone(),
is_static: false,
outlives: HashSet::new(),
source: LifetimeSource::Parameter(param_name.clone()),
},
);
if usage.escapes {
self.add_constraint(<, "'return", LifetimeConstraint::Outlives);
}
Some(lt)
} else {
None
};
result.insert(
param_name.clone(),
InferredParam {
should_borrow,
needs_mut,
lifetime,
rust_type,
},
);
}
result
}
fn analyze_return_lifetime(&mut self, func: &HirFunction, type_mapper: &crate::type_mapper::TypeMapper) -> Option<String> {
let return_rust_type = type_mapper.map_return_type(&func.ret_type);
if self.return_type_needs_lifetime(&return_rust_type) {
for (param_name, usage) in &self.param_analysis {
if usage.escapes {
if let Some(info) = self.variable_lifetimes.get(param_name) {
return Some(info.name.clone());
}
}
}
Some(self.next_lifetime())
} else {
None
}
}
#[allow(clippy::only_used_in_recursion)]
fn return_type_needs_lifetime(&self, rust_type: &RustType) -> bool {
match rust_type {
RustType::Str { .. } => true,
RustType::Reference { .. } => true,
RustType::Cow { .. } => true,
RustType::Vec(inner) | RustType::Option(inner) => self.return_type_needs_lifetime(inner),
RustType::Result(ok, err) => {
self.return_type_needs_lifetime(ok) || self.return_type_needs_lifetime(err)
}
RustType::Tuple(types) => types.iter().any(|t| self.return_type_needs_lifetime(t)),
_ => false,
}
}
fn compute_lifetime_bounds(&self) -> Vec<(String, String)> {
let mut bounds = Vec::new();
for (from, tos) in &self.lifetime_constraints {
for to in tos {
if to != "'return" { bounds.push((from.clone(), to.clone()));
}
}
}
bounds
}
pub fn apply_elision_rules(&mut self, func: &HirFunction, type_mapper: &crate::type_mapper::TypeMapper) -> Option<LifetimeResult> {
let full_result = self.analyze_function(func, type_mapper);
let ref_params: Vec<_> = full_result.param_lifetimes.iter()
.filter(|(_, param)| param.should_borrow)
.collect();
let return_needs_lifetime = full_result.return_lifetime.is_some();
if ref_params.is_empty() {
return Some(LifetimeResult {
param_lifetimes: full_result.param_lifetimes,
return_lifetime: None,
lifetime_params: vec![],
lifetime_bounds: vec![],
});
}
if ref_params.len() == 1 {
if return_needs_lifetime {
return Some(LifetimeResult {
param_lifetimes: full_result.param_lifetimes,
return_lifetime: None, lifetime_params: vec![], lifetime_bounds: vec![],
});
}
}
if let Some((_name, _)) = ref_params.iter().find(|(name, _)| *name == "self") {
if return_needs_lifetime {
return Some(LifetimeResult {
param_lifetimes: full_result.param_lifetimes,
return_lifetime: None, lifetime_params: vec![],
lifetime_bounds: vec![],
});
}
}
Some(full_result)
}
#[allow(dead_code)]
fn is_reference_type(&self, rust_type: &RustType) -> bool {
matches!(
rust_type,
RustType::Str { .. } | RustType::Reference { .. } | RustType::Cow { .. }
)
}
}
impl Default for LifetimeInference {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::{HirFunction, Type as PythonType, FunctionProperties, Literal};
    use smallvec::smallvec;
    use depyler_annotations::TranspilationAnnotations;

    /// Builds a one-parameter function with default properties and
    /// annotations and no docstring — the only shape these tests need.
    fn single_param_function(
        name: &str,
        param: (&str, PythonType),
        ret_type: PythonType,
        body: Vec<HirStmt>,
    ) -> HirFunction {
        let (param_name, param_type) = param;
        HirFunction {
            name: name.to_string(),
            params: smallvec![(param_name.to_string(), param_type)],
            ret_type,
            body,
            properties: FunctionProperties::default(),
            annotations: TranspilationAnnotations::default(),
            docstring: None,
        }
    }

    /// `HirExpr::Var` for the given identifier.
    fn var(name: &str) -> HirExpr {
        HirExpr::Var(name.to_string())
    }

    #[test]
    fn test_lifetime_generation() {
        let mut inference = LifetimeInference::new();
        for expected in ["'a", "'b", "'c", "'l1"] {
            assert_eq!(inference.next_lifetime(), expected);
        }
    }

    #[test]
    fn test_parameter_usage_analysis() {
        // def test(x: str) -> str: return x
        let func = single_param_function(
            "test",
            ("x", PythonType::String),
            PythonType::String,
            vec![HirStmt::Return(Some(var("x")))],
        );
        let mut inference = LifetimeInference::new();
        inference.analyze_parameter_usage(&func);

        let usage = inference
            .param_analysis
            .get("x")
            .expect("usage recorded for every parameter");
        assert!(usage.is_read_only);
        assert!(usage.escapes);
        assert!(!usage.is_mutated);
    }

    #[test]
    fn test_lifetime_inference() {
        // def get_len(s: str) -> int: return s.len
        let len_of_s = HirExpr::Attribute {
            value: Box::new(var("s")),
            attr: "len".to_string(),
        };
        let func = single_param_function(
            "get_len",
            ("s", PythonType::String),
            PythonType::Int,
            vec![HirStmt::Return(Some(len_of_s))],
        );
        let type_mapper = crate::type_mapper::TypeMapper::new();
        let mut inference = LifetimeInference::new();
        let result = inference.analyze_function(&func, &type_mapper);

        let s_param = result.param_lifetimes.get("s").expect("parameter `s` analysed");
        assert!(s_param.should_borrow);
        assert!(!s_param.needs_mut);
        assert!(s_param.lifetime.is_some());
    }

    #[test]
    fn test_elision_rules() {
        // def identity(x: str) -> str: (empty body)
        let func = single_param_function(
            "identity",
            ("x", PythonType::String),
            PythonType::String,
            Vec::new(),
        );
        let type_mapper = crate::type_mapper::TypeMapper::new();
        let mut inference = LifetimeInference::new();

        let result = inference
            .apply_elision_rules(&func, &type_mapper)
            .expect("elision always yields a result");
        assert!(result.lifetime_params.is_empty());
    }

    #[test]
    fn test_mutable_parameter_detection() {
        // def append_bang(s: str) -> None: s = s + "!"
        let s_plus_bang = HirExpr::Binary {
            op: crate::hir::BinOp::Add,
            left: Box::new(var("s")),
            right: Box::new(HirExpr::Literal(Literal::String("!".to_string()))),
        };
        let func = single_param_function(
            "append_bang",
            ("s", PythonType::String),
            PythonType::None,
            vec![HirStmt::Assign {
                target: "s".to_string(),
                value: s_plus_bang,
            }],
        );
        let type_mapper = crate::type_mapper::TypeMapper::new();
        let mut inference = LifetimeInference::new();
        let result = inference.analyze_function(&func, &type_mapper);

        let s_param = result.param_lifetimes.get("s").expect("parameter `s` analysed");
        assert!(s_param.should_borrow);
        assert!(s_param.needs_mut);
    }
}