use crate::hir::{HirExpr, HirFunction, HirStmt, Type};
use anyhow::Result;
use std::collections::{HashMap, HashSet};
/// Registry of declared type variables and of the generic parameters
/// inferred for each function.
#[derive(Default, Debug)]
pub struct TypeVarRegistry {
    /// Declared type variables, keyed by variable name.
    type_vars: HashMap<String, TypeVarConstraints>,
    /// Generic parameters produced by `infer_function_generics`, keyed by
    /// function name.
    function_type_params: HashMap<String, Vec<TypeParameter>>,
    // Not read anywhere in this module yet; the allow keeps the field
    // without a warning until binding tracking is implemented.
    #[allow(dead_code)]
    active_bindings: HashMap<String, Type>,
}
/// Declared constraints attached to one named type variable.
#[derive(Clone, Debug)]
pub struct TypeVarConstraints {
    /// Name of the type variable (for example `T`).
    pub name: String,
    /// Bounds the variable is declared with.
    pub bounds: Vec<TypeBound>,
    /// Declared variance of the variable.
    pub variance: Variance,
    /// Optional default type.
    pub default: Option<Type>,
}
/// A single bound placed on a type variable.
#[derive(Clone, Debug, PartialEq)]
pub enum TypeBound {
    /// Upper bound given as a type.
    UpperBound(Type),
    /// Bound given as a trait name, e.g. `"Clone"`.
    TraitBound(String),
    /// Bound given as a list of allowed types.
    UnionBound(Vec<Type>),
}
/// Declared variance of a type variable.
///
/// A fieldless `Copy` enum, so it also derives `Eq` and `Hash`; this lets it
/// be used as a `HashMap`/`HashSet` key and silences clippy's
/// `derive_partial_eq_without_eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Variance {
    Invariant,
    Covariant,
    Contravariant,
}
/// A generic parameter to emit on a generated function.
#[derive(Clone, Debug)]
pub struct TypeParameter {
    /// Parameter name, e.g. `T`.
    pub name: String,
    /// Trait bounds written as Rust paths, e.g. `"Clone"` or
    /// `"std::hash::Hash"`.
    pub bounds: Vec<String>,
    /// Optional default type.
    pub default: Option<Type>,
}
/// Structured form of a type that may mention type variables.
#[derive(Clone, Debug, PartialEq)]
pub enum GenericType {
    /// A type variable referenced by name.
    TypeVar(String),
    /// A base type applied to generic arguments.
    Generic {
        base: Type,
        params: Vec<GenericType>,
    },
    /// A union of types.
    Union(Vec<Type>),
    /// A type with no type variables.
    Concrete(Type),
}
impl TypeVarRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers (or replaces) the constraints for the type variable `name`.
    pub fn register_type_var(&mut self, name: String, constraints: TypeVarConstraints) {
        self.type_vars.insert(name, constraints);
    }

    /// Runs usage-based inference over `func` and returns the concrete types
    /// it pinned down, keyed by type-variable name.
    ///
    /// # Errors
    /// Propagates any error from the function analysis.
    pub fn infer_type_substitutions(&self, func: &HirFunction) -> Result<HashMap<String, Type>> {
        let mut inference = TypeInference::new();
        inference.analyze_function(func)?;
        Ok(inference.substitutions)
    }

    /// Rewrites `ty`, replacing type variables with their entry in
    /// `substitutions`.
    ///
    /// * `Type::Unknown` is resolved through the `"T"` key, which is where the
    ///   inference pass records the element type of untyped containers.
    /// * `Type::TypeVar(name)` / `Type::Custom(name)` are resolved through
    ///   `name`, matching the keys inference writes. `infer_function_generics`
    ///   drops such variables from the generic list, so they must be replaced
    ///   here too.
    /// * Container, union, function and generic types are rewritten
    ///   recursively; anything else is returned unchanged.
    pub fn apply_substitutions(ty: &Type, substitutions: &HashMap<String, Type>) -> Type {
        let recurse = |t: &Type| Self::apply_substitutions(t, substitutions);
        match ty {
            Type::Unknown => substitutions.get("T").cloned().unwrap_or(Type::Unknown),
            Type::TypeVar(name) | Type::Custom(name) => substitutions
                .get(name)
                .cloned()
                .unwrap_or_else(|| ty.clone()),
            Type::List(inner) => Type::List(Box::new(recurse(inner))),
            Type::Set(inner) => Type::Set(Box::new(recurse(inner))),
            Type::Optional(inner) => Type::Optional(Box::new(recurse(inner))),
            Type::Dict(k, v) => Type::Dict(Box::new(recurse(k)), Box::new(recurse(v))),
            Type::Tuple(types) => Type::Tuple(types.iter().map(recurse).collect()),
            Type::Union(types) => Type::Union(types.iter().map(recurse).collect()),
            Type::Function { params, ret } => Type::Function {
                params: params.iter().map(recurse).collect(),
                ret: Box::new(recurse(ret)),
            },
            Type::Generic { base, params } => Type::Generic {
                base: base.clone(),
                params: params.iter().map(recurse).collect(),
            },
            other => other.clone(),
        }
    }

    /// Infers the generic parameters `func` needs and records them under the
    /// function's name.
    ///
    /// Type variables are collected from the parameter and return types.
    /// Variables that inference resolves to a concrete type are dropped.
    /// Bounds come from the usage constraints gathered by inference.
    ///
    /// # Errors
    /// Propagates any error from the function analysis.
    pub fn infer_function_generics(&mut self, func: &HirFunction) -> Result<Vec<TypeParameter>> {
        let mut collector = TypeVarCollector::new();
        for param in &func.params {
            collector.collect_from_type(&param.ty);
        }
        collector.collect_from_type(&func.ret_type);

        let mut inference = TypeInference::new();
        inference.analyze_function(func)?;

        // A variable with a concrete substitution is no longer generic.
        let unresolved = |vars: HashSet<String>| -> HashSet<String> {
            vars.into_iter()
                .filter(|tv| !inference.substitutions.contains_key(tv))
                .collect()
        };
        let type_vars = unresolved(collector.type_vars);
        let dict_key_vars = unresolved(collector.dict_key_type_vars);

        let type_params =
            self.generate_type_parameters(&type_vars, &inference.constraints, &dict_key_vars)?;
        self.function_type_params
            .insert(func.name.clone(), type_params.clone());
        Ok(type_params)
    }

    /// Returns `true` if `ty` mentions a type variable.
    ///
    /// A variable is either an explicit `Type::TypeVar` or a `Type::Custom`
    /// name registered in this registry.
    pub fn is_generic(&self, ty: &Type) -> bool {
        match ty {
            Type::TypeVar(_) => true,
            Type::Custom(name) => self.type_vars.contains_key(name),
            Type::List(inner) | Type::Optional(inner) | Type::Set(inner) => self.is_generic(inner),
            Type::Dict(k, v) => self.is_generic(k) || self.is_generic(v),
            Type::Tuple(types) | Type::Union(types) => types.iter().any(|t| self.is_generic(t)),
            Type::Generic { params, .. } => params.iter().any(|t| self.is_generic(t)),
            Type::Function { params, ret } => {
                params.iter().any(|t| self.is_generic(t)) || self.is_generic(ret)
            }
            _ => false,
        }
    }

    /// Renders `name<P1, P2, ...>`, or just `name` when `params` is empty.
    pub fn to_rust_generic(&self, name: &str, params: &[Type]) -> String {
        if params.is_empty() {
            return name.to_string();
        }
        let rendered: Vec<String> = params.iter().map(|p| self.type_to_rust_string(p)).collect();
        format!("{}<{}>", name, rendered.join(", "))
    }

    /// Renders `ty` as Rust source text.
    ///
    /// Type variables render as their bare name. Types with no mapping fall
    /// back to `()`.
    fn type_to_rust_string(&self, ty: &Type) -> String {
        match ty {
            Type::Int => "i32".to_string(),
            Type::Float => "f64".to_string(),
            Type::String => "String".to_string(),
            Type::Bool => "bool".to_string(),
            Type::None => "()".to_string(),
            Type::List(inner) => format!("Vec<{}>", self.type_to_rust_string(inner)),
            Type::Set(inner) => format!("HashSet<{}>", self.type_to_rust_string(inner)),
            Type::Dict(k, v) => format!(
                "HashMap<{}, {}>",
                self.type_to_rust_string(k),
                self.type_to_rust_string(v)
            ),
            Type::Optional(inner) => format!("Option<{}>", self.type_to_rust_string(inner)),
            Type::Tuple(types) => {
                let type_strs: Vec<String> =
                    types.iter().map(|t| self.type_to_rust_string(t)).collect();
                format!("({})", type_strs.join(", "))
            }
            Type::Generic { base, params } => self.to_rust_generic(base, params),
            Type::Custom(name) | Type::TypeVar(name) => name.clone(),
            _ => "()".to_string(),
        }
    }

    /// Builds one `TypeParameter` per variable in `type_vars`.
    ///
    /// Each parameter's bounds are:
    /// * the `MustImplement` constraints inference recorded for it;
    /// * `Eq` + `std::hash::Hash` when it is used as a dictionary key;
    /// * `Clone`, which is always added.
    ///
    /// Parameters are returned sorted by name and each bound list is sorted,
    /// so the generated signature is deterministic.
    fn generate_type_parameters(
        &self,
        type_vars: &HashSet<String>,
        constraints: &HashMap<String, Vec<TypeConstraint>>,
        dict_key_type_vars: &HashSet<String>,
    ) -> Result<Vec<TypeParameter>> {
        // HashSet iteration order varies from run to run; sort so the emitted
        // `<T, U, ...>` list is stable.
        let mut names: Vec<&String> = type_vars.iter().collect();
        names.sort_unstable();

        let params = names
            .into_iter()
            .map(|var| {
                // A BTreeSet keeps the bounds deduplicated and sorted.
                let mut bounds: std::collections::BTreeSet<String> = constraints
                    .get(var)
                    .into_iter()
                    .flatten()
                    .filter_map(|constraint| match constraint {
                        TypeConstraint::MustImplement(trait_name) => Some(trait_name.clone()),
                        TypeConstraint::MustBe(_) | TypeConstraint::SubtypeOf(_) => None,
                    })
                    .collect();
                if dict_key_type_vars.contains(var) {
                    bounds.insert("Eq".to_string());
                    bounds.insert("std::hash::Hash".to_string());
                }
                bounds.insert("Clone".to_string());
                TypeParameter {
                    name: var.clone(),
                    bounds: bounds.into_iter().collect(),
                    default: None,
                }
            })
            .collect();
        Ok(params)
    }
}
/// Accumulates the type variables mentioned in a function signature.
struct TypeVarCollector {
    /// Every type variable encountered.
    type_vars: HashSet<String>,
    /// Type variables used directly as a dictionary key; these later receive
    /// `Eq` and `std::hash::Hash` bounds.
    dict_key_type_vars: HashSet<String>,
}
impl TypeVarCollector {
    /// Creates a collector with nothing recorded.
    fn new() -> Self {
        Self {
            type_vars: HashSet::new(),
            dict_key_type_vars: HashSet::new(),
        }
    }

    /// Returns `true` when a `Type::Custom` name is treated as a type
    /// variable: a one-byte (i.e. ASCII) uppercase name such as `T` or `K`.
    fn is_type_var_name(name: &str) -> bool {
        name.len() == 1 && name.chars().next().is_some_and(|c| c.is_uppercase())
    }

    /// Recursively records every type variable mentioned in `ty`.
    ///
    /// A variable used directly as a dictionary key is also recorded in
    /// `dict_key_type_vars`. This covers both `Type::TypeVar` and
    /// single-letter `Type::Custom` keys. Those variables need `Eq + Hash`
    /// to be usable as a `HashMap` key.
    fn collect_from_type(&mut self, ty: &Type) {
        match ty {
            Type::Custom(name) if Self::is_type_var_name(name) => {
                self.type_vars.insert(name.clone());
            }
            Type::TypeVar(name) => {
                self.type_vars.insert(name.clone());
            }
            Type::List(inner) | Type::Optional(inner) | Type::Set(inner) => {
                self.collect_from_type(inner);
            }
            Type::Dict(k, v) => {
                let key_var = match k.as_ref() {
                    Type::TypeVar(name) => Some(name),
                    Type::Custom(name) if Self::is_type_var_name(name) => Some(name),
                    _ => None,
                };
                if let Some(name) = key_var {
                    self.dict_key_type_vars.insert(name.clone());
                }
                self.collect_from_type(k);
                self.collect_from_type(v);
            }
            Type::Tuple(types) | Type::Union(types) => {
                for t in types {
                    self.collect_from_type(t);
                }
            }
            Type::Function { params, ret } => {
                for p in params {
                    self.collect_from_type(p);
                }
                self.collect_from_type(ret);
            }
            Type::Generic { params, .. } => {
                for p in params {
                    self.collect_from_type(p);
                }
            }
            // Unknown, concrete scalars and multi-letter custom names carry no
            // type variables.
            _ => {}
        }
    }

    /// Walks `stmt` and the expressions and nested statements inside it.
    ///
    /// Unused for now (hence the allow). The expression walker it delegates
    /// to does not currently record anything.
    #[allow(dead_code)]
    fn collect_from_stmt(&mut self, stmt: &HirStmt) {
        match stmt {
            HirStmt::Assign { value, .. } => self.collect_from_expr(value),
            HirStmt::Return(Some(expr)) => self.collect_from_expr(expr),
            HirStmt::If {
                condition,
                then_body,
                else_body,
            } => {
                self.collect_from_expr(condition);
                for s in then_body.iter().chain(else_body.iter().flatten()) {
                    self.collect_from_stmt(s);
                }
            }
            HirStmt::While { condition, body } => {
                self.collect_from_expr(condition);
                for s in body {
                    self.collect_from_stmt(s);
                }
            }
            HirStmt::For { iter, body, .. } => {
                self.collect_from_expr(iter);
                for s in body {
                    self.collect_from_stmt(s);
                }
            }
            HirStmt::Expr(expr) => self.collect_from_expr(expr),
            _ => {}
        }
    }

    /// Recursively traverses `expr`'s sub-expressions.
    ///
    /// No expression form records a type variable yet, so this is traversal
    /// only. It is unused for now; the clippy allow covers `self` being used
    /// only for recursion.
    #[allow(dead_code, clippy::only_used_in_recursion)]
    fn collect_from_expr(&mut self, expr: &HirExpr) {
        match expr {
            HirExpr::Binary { left, right, .. } => {
                self.collect_from_expr(left);
                self.collect_from_expr(right);
            }
            HirExpr::Unary { operand, .. } => self.collect_from_expr(operand),
            HirExpr::Call { args, .. } => {
                for arg in args {
                    self.collect_from_expr(arg);
                }
            }
            HirExpr::MethodCall { object, args, .. } => {
                self.collect_from_expr(object);
                for arg in args {
                    self.collect_from_expr(arg);
                }
            }
            HirExpr::Index { base, index } => {
                self.collect_from_expr(base);
                self.collect_from_expr(index);
            }
            HirExpr::List(elems) => {
                for elem in elems {
                    self.collect_from_expr(elem);
                }
            }
            HirExpr::Dict(pairs) => {
                for (k, v) in pairs {
                    self.collect_from_expr(k);
                    self.collect_from_expr(v);
                }
            }
            HirExpr::Tuple(elems) => {
                for elem in elems {
                    self.collect_from_expr(elem);
                }
            }
            _ => {}
        }
    }
}
/// Usage-driven inference state for a single function.
struct TypeInference {
    /// Constraints gathered for each type variable.
    constraints: HashMap<String, Vec<TypeConstraint>>,
    /// Concrete types that type variables were resolved to.
    substitutions: HashMap<String, Type>,
    /// Declared type of each function parameter, keyed by parameter name.
    param_types: HashMap<String, Type>,
    /// Loop variables that iterate directly over a parameter, mapped to that
    /// parameter's type variable.
    loop_var_to_type_param: HashMap<String, String>,
}
/// A requirement discovered for a type variable during inference.
#[derive(Clone, Debug)]
enum TypeConstraint {
    /// Must be exactly this type. Not constructed in this module yet, hence
    /// the allow.
    #[allow(dead_code)]
    MustBe(Type),
    /// Must be a subtype of this type. Not constructed in this module yet,
    /// hence the allow.
    #[allow(dead_code)]
    SubtypeOf(Type),
    /// Must implement the named trait, written as a Rust path such as
    /// `"PartialEq"`.
    MustImplement(String),
}
impl TypeInference {
fn new() -> Self {
Self {
constraints: HashMap::new(),
substitutions: HashMap::new(),
param_types: HashMap::new(),
loop_var_to_type_param: HashMap::new(),
}
}
fn analyze_function(&mut self, func: &HirFunction) -> Result<()> {
for param in &func.params {
self.param_types
.insert(param.name.clone(), param.ty.clone());
}
for param in &func.params {
match ¶m.ty {
Type::Custom(type_var)
if type_var.len() == 1 && type_var.chars().next().unwrap().is_uppercase() =>
{
self.analyze_param_usage(¶m.name, type_var, &func.body)?;
}
Type::TypeVar(type_var) => {
self.analyze_param_usage(¶m.name, type_var, &func.body)?;
}
Type::List(inner) if matches!(**inner, Type::Unknown) => {
self.analyze_param_usage(¶m.name, "T", &func.body)?;
}
Type::Dict(k, v)
if matches!(**k, Type::Unknown) || matches!(**v, Type::Unknown) =>
{
self.analyze_param_usage(¶m.name, "T", &func.body)?;
}
Type::Optional(inner) if matches!(**inner, Type::Unknown) => {
if let Some(concrete_type) =
self.infer_concrete_type_from_returns(¶m.name, &func.body)
{
self.substitutions.insert("T".to_string(), concrete_type);
}
}
_ => {}
}
}
Ok(())
}
fn analyze_param_usage(
&mut self,
param_name: &str,
type_var: &str,
body: &[HirStmt],
) -> Result<()> {
for stmt in body {
self.analyze_stmt_for_param(param_name, type_var, stmt)?;
}
Ok(())
}
fn analyze_stmt_for_param(
&mut self,
param_name: &str,
type_var: &str,
stmt: &HirStmt,
) -> Result<()> {
match stmt {
HirStmt::Expr(expr) => {
self.analyze_expr_for_param(param_name, type_var, expr)?;
}
HirStmt::Assign { value, .. } => {
self.analyze_expr_for_param(param_name, type_var, value)?;
}
HirStmt::Return(Some(expr)) => {
self.analyze_expr_for_param(param_name, type_var, expr)?;
}
HirStmt::If {
condition,
then_body,
else_body,
} => {
self.analyze_expr_for_param(param_name, type_var, condition)?;
for s in then_body {
self.analyze_stmt_for_param(param_name, type_var, s)?;
}
if let Some(else_stmts) = else_body {
for s in else_stmts {
self.analyze_stmt_for_param(param_name, type_var, s)?;
}
}
}
HirStmt::For { target, iter, body } => {
self.analyze_expr_for_param(param_name, type_var, iter)?;
if let HirExpr::Var(iter_var) = iter {
if iter_var == param_name {
if let crate::hir::AssignTarget::Symbol(loop_var) = target {
self.loop_var_to_type_param
.insert(loop_var.clone(), type_var.to_string());
}
}
}
for s in body {
self.analyze_stmt_for_param(param_name, type_var, s)?;
}
}
HirStmt::While { condition, body } => {
self.analyze_expr_for_param(param_name, type_var, condition)?;
for s in body {
self.analyze_stmt_for_param(param_name, type_var, s)?;
}
}
_ => {}
}
Ok(())
}
fn analyze_expr_for_param(
&mut self,
param_name: &str,
type_var: &str,
expr: &HirExpr,
) -> Result<()> {
match expr {
HirExpr::Binary { left, right, op } => {
self.check_binary_op_usage(param_name, type_var, left, right, *op)?;
self.analyze_expr_for_param(param_name, type_var, left)?;
self.analyze_expr_for_param(param_name, type_var, right)?;
}
HirExpr::MethodCall {
object,
method,
args,
..
} => {
if let HirExpr::Var(var) = object.as_ref() {
if var == param_name {
self.add_method_constraint(type_var, method);
if (method == "get" || method == "__getitem__") && !args.is_empty() {
let key_arg = &args[0];
let key_is_string = match key_arg {
HirExpr::Literal(crate::hir::Literal::String(_)) => true,
HirExpr::Var(v) => self.is_string_typed(v),
_ => false,
};
if key_is_string {
self.substitutions
.insert(type_var.to_string(), Type::String);
}
}
}
}
self.analyze_expr_for_param(param_name, type_var, object)?;
for arg in args {
self.analyze_expr_for_param(param_name, type_var, arg)?;
}
}
HirExpr::Call { args, .. } => {
for arg in args {
self.analyze_expr_for_param(param_name, type_var, arg)?;
}
}
_ => {}
}
Ok(())
}
fn add_method_constraint(&mut self, type_var: &str, method: &str) {
let constraint = match method {
"len" => TypeConstraint::MustImplement("HasLen".to_string()),
"push" | "pop" => TypeConstraint::MustImplement("VecLike".to_string()),
"clone" => TypeConstraint::MustImplement("Clone".to_string()),
_ => return,
};
self.constraints
.entry(type_var.to_string())
.or_default()
.push(constraint);
}
fn check_binary_op_usage(
&mut self,
param_name: &str,
type_var: &str,
left: &HirExpr,
right: &HirExpr,
op: crate::hir::BinOp,
) -> Result<()> {
use crate::hir::BinOp;
let (uses_param, other_operand) = match (left, right) {
(HirExpr::Var(l), other) if l == param_name => (true, Some(other)),
(other, HirExpr::Var(r)) if r == param_name => (true, Some(other)),
(HirExpr::Var(l), other)
if self
.loop_var_to_type_param
.get(l)
.is_some_and(|tv| tv == type_var) =>
{
(true, Some(other))
}
(other, HirExpr::Var(r))
if self
.loop_var_to_type_param
.get(r)
.is_some_and(|tv| tv == type_var) =>
{
(true, Some(other))
}
_ => (false, None),
};
if uses_param {
let constraint = match op {
BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div => {
TypeConstraint::MustImplement("std::ops::Add".to_string())
}
BinOp::Eq | BinOp::NotEq => {
let target_is_string = other_operand.is_some_and(|op| match op {
HirExpr::Literal(crate::hir::Literal::String(_)) => true,
HirExpr::Var(v) => self.is_string_typed(v),
_ => false,
});
if target_is_string {
self.substitutions
.insert(type_var.to_string(), Type::String);
TypeConstraint::MustImplement("PartialEq".to_string())
} else {
TypeConstraint::MustImplement("PartialEq".to_string())
}
}
BinOp::Lt | BinOp::LtEq | BinOp::Gt | BinOp::GtEq => {
TypeConstraint::MustImplement("PartialOrd".to_string())
}
BinOp::In | BinOp::NotIn => {
if let HirExpr::Var(r) = right {
if r == param_name {
let key_is_string = match left {
HirExpr::Literal(crate::hir::Literal::String(_)) => true,
HirExpr::Var(v) => self.is_string_typed(v),
_ => false,
};
if key_is_string {
self.substitutions
.insert(type_var.to_string(), Type::String);
}
}
}
return Ok(()); }
_ => return Ok(()),
};
self.constraints
.entry(type_var.to_string())
.or_default()
.push(constraint);
}
Ok(())
}
fn is_string_typed(&self, var_name: &str) -> bool {
self.param_types
.get(var_name)
.is_some_and(|ty| matches!(ty, Type::String))
}
#[allow(dead_code)]
fn infer_concrete_type_from_returns(
&self,
_optional_param_name: &str,
body: &[HirStmt],
) -> Option<Type> {
let mut concrete_types = Vec::new();
self.collect_return_types_from_stmts(body, &mut concrete_types);
concrete_types.into_iter().find(|ty| {
!matches!(
ty,
Type::Optional(_) | Type::Unknown | Type::None | Type::TypeVar(_)
)
})
}
fn collect_return_types_from_stmts(&self, stmts: &[HirStmt], types: &mut Vec<Type>) {
for stmt in stmts {
match stmt {
HirStmt::Return(Some(expr)) => {
if let Some(ty) = self.infer_expr_type(expr) {
types.push(ty);
}
}
HirStmt::If {
then_body,
else_body,
..
} => {
self.collect_return_types_from_stmts(then_body, types);
if let Some(else_stmts) = else_body {
self.collect_return_types_from_stmts(else_stmts, types);
}
}
HirStmt::While { body, .. } | HirStmt::For { body, .. } => {
self.collect_return_types_from_stmts(body, types);
}
HirStmt::With { body, .. } | HirStmt::Try { body, .. } => {
self.collect_return_types_from_stmts(body, types);
}
_ => {}
}
}
}
fn infer_expr_type(&self, expr: &HirExpr) -> Option<Type> {
use crate::hir::Literal;
match expr {
HirExpr::Literal(lit) => match lit {
Literal::Int(_) => Some(Type::Int),
Literal::Float(_) => Some(Type::Float),
Literal::String(_) => Some(Type::String),
Literal::Bool(_) => Some(Type::Bool),
Literal::None => Some(Type::None),
_ => None,
},
HirExpr::Var(name) => {
self.param_types.get(name).cloned()
}
_ => None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::*;

    /// Shorthand for `Type::Custom(name)`.
    fn custom(name: &str) -> Type {
        Type::Custom(name.to_string())
    }

    /// Shorthand for `Type::TypeVar(name)`.
    fn type_var(name: &str) -> Type {
        Type::TypeVar(name.to_string())
    }

    /// Constraints for a variable named `T`.
    fn constraints_for_t(
        bounds: Vec<TypeBound>,
        variance: Variance,
        default: Option<Type>,
    ) -> TypeVarConstraints {
        TypeVarConstraints {
            name: "T".to_string(),
            bounds,
            variance,
            default,
        }
    }

    /// A registry in which `T` is registered, invariant and unbounded.
    fn registry_with_t() -> TypeVarRegistry {
        let mut registry = TypeVarRegistry::new();
        registry.register_type_var(
            "T".to_string(),
            constraints_for_t(vec![], Variance::Invariant, None),
        );
        registry
    }

    /// Applies a substitution map `{T: replacement}` to `target`.
    fn substitute_t(replacement: Type, target: &Type) -> Type {
        let mut subs = HashMap::new();
        subs.insert("T".to_string(), replacement);
        TypeVarRegistry::apply_substitutions(target, &subs)
    }

    /// A one-parameter function with default properties and annotations.
    fn one_param_fn(
        name: &str,
        param: HirParam,
        ret_type: Type,
        body: Vec<HirStmt>,
    ) -> HirFunction {
        HirFunction {
            name: name.to_string(),
            params: smallvec::smallvec![param],
            ret_type,
            body,
            properties: FunctionProperties::default(),
            annotations: Default::default(),
            docstring: None,
        }
    }

    /// Whether `var` carries `MustImplement(trait_name)`; panics if `var`
    /// has no constraints at all.
    fn requires_trait(inference: &TypeInference, var: &str, trait_name: &str) -> bool {
        inference.constraints[var]
            .iter()
            .any(|c| matches!(c, TypeConstraint::MustImplement(s) if s == trait_name))
    }

    #[test]
    fn test_type_var_registry_new() {
        let fresh = TypeVarRegistry::new();
        assert!(fresh.type_vars.is_empty());
        assert!(fresh.function_type_params.is_empty());
    }

    #[test]
    fn test_type_var_registry_register() {
        let mut registry = TypeVarRegistry::new();
        let bounds = vec![TypeBound::TraitBound("Clone".to_string())];
        registry.register_type_var(
            "T".to_string(),
            constraints_for_t(bounds, Variance::Covariant, None),
        );
        assert!(registry.type_vars.contains_key("T"));
    }

    #[test]
    fn test_apply_substitutions_unknown() {
        assert_eq!(substitute_t(Type::String, &Type::Unknown), Type::String);
    }

    #[test]
    fn test_apply_substitutions_no_match() {
        let empty = HashMap::new();
        assert_eq!(
            TypeVarRegistry::apply_substitutions(&Type::Unknown, &empty),
            Type::Unknown
        );
    }

    #[test]
    fn test_apply_substitutions_list() {
        let target = Type::List(Box::new(Type::Unknown));
        assert_eq!(
            substitute_t(Type::Int, &target),
            Type::List(Box::new(Type::Int))
        );
    }

    #[test]
    fn test_apply_substitutions_dict() {
        let target = Type::Dict(Box::new(Type::String), Box::new(Type::Unknown));
        let expected = Type::Dict(Box::new(Type::String), Box::new(Type::Int));
        assert_eq!(substitute_t(Type::Int, &target), expected);
    }

    #[test]
    fn test_apply_substitutions_optional() {
        let target = Type::Optional(Box::new(Type::Unknown));
        assert_eq!(
            substitute_t(Type::Float, &target),
            Type::Optional(Box::new(Type::Float))
        );
    }

    #[test]
    fn test_apply_substitutions_tuple() {
        let target = Type::Tuple(vec![Type::Int, Type::Unknown, Type::String]);
        let expected = Type::Tuple(vec![Type::Int, Type::Bool, Type::String]);
        assert_eq!(substitute_t(Type::Bool, &target), expected);
    }

    #[test]
    fn test_apply_substitutions_concrete() {
        let empty = HashMap::new();
        assert_eq!(
            TypeVarRegistry::apply_substitutions(&Type::Int, &empty),
            Type::Int
        );
    }

    #[test]
    fn test_is_generic_with_type_var() {
        let registry = registry_with_t();
        assert!(registry.is_generic(&custom("T")));
        assert!(!registry.is_generic(&custom("String")));
    }

    #[test]
    fn test_is_generic_nested() {
        let registry = registry_with_t();
        assert!(registry.is_generic(&Type::List(Box::new(custom("T")))));
        assert!(registry.is_generic(&Type::Optional(Box::new(custom("T")))));
        assert!(registry.is_generic(&Type::Dict(
            Box::new(Type::String),
            Box::new(custom("T"))
        )));
    }

    #[test]
    fn test_is_generic_tuple() {
        let registry = registry_with_t();
        assert!(registry.is_generic(&Type::Tuple(vec![Type::Int, custom("T")])));
        assert!(!registry.is_generic(&Type::Tuple(vec![Type::Int, Type::String])));
    }

    #[test]
    fn test_is_generic_function() {
        let registry = registry_with_t();
        let takes_t = Type::Function {
            params: vec![custom("T")],
            ret: Box::new(Type::Int),
        };
        let returns_t = Type::Function {
            params: vec![Type::Int],
            ret: Box::new(custom("T")),
        };
        assert!(registry.is_generic(&takes_t));
        assert!(registry.is_generic(&returns_t));
    }

    #[test]
    fn test_to_rust_generic_no_params() {
        assert_eq!(TypeVarRegistry::new().to_rust_generic("Vec", &[]), "Vec");
    }

    #[test]
    fn test_to_rust_generic_with_params() {
        let rendered = TypeVarRegistry::new().to_rust_generic("HashMap", &[Type::String, Type::Int]);
        assert_eq!(rendered, "HashMap<String, i32>");
    }

    #[test]
    fn test_type_to_rust_string() {
        let registry = TypeVarRegistry::new();
        let cases = [
            (Type::Int, "i32"),
            (Type::Float, "f64"),
            (Type::String, "String"),
            (Type::Bool, "bool"),
            (Type::None, "()"),
        ];
        for (ty, expected) in cases.iter() {
            assert_eq!(registry.type_to_rust_string(ty), *expected);
        }
    }

    #[test]
    fn test_type_to_rust_string_containers() {
        let registry = TypeVarRegistry::new();
        let list = Type::List(Box::new(Type::Int));
        let dict = Type::Dict(Box::new(Type::String), Box::new(Type::Int));
        let optional = Type::Optional(Box::new(Type::Int));
        assert_eq!(registry.type_to_rust_string(&list), "Vec<i32>");
        assert_eq!(registry.type_to_rust_string(&dict), "HashMap<String, i32>");
        assert_eq!(registry.type_to_rust_string(&optional), "Option<i32>");
    }

    #[test]
    fn test_type_to_rust_string_tuple() {
        let tuple = Type::Tuple(vec![Type::Int, Type::String, Type::Bool]);
        assert_eq!(
            TypeVarRegistry::new().type_to_rust_string(&tuple),
            "(i32, String, bool)"
        );
    }

    #[test]
    fn test_type_var_constraints_clone() {
        let original = constraints_for_t(
            vec![TypeBound::TraitBound("Clone".to_string())],
            Variance::Covariant,
            Some(Type::Int),
        );
        let copy = original.clone();
        assert_eq!(copy.name, "T");
        assert_eq!(copy.variance, Variance::Covariant);
    }

    #[test]
    fn test_type_var_constraints_debug() {
        let rendered = format!(
            "{:?}",
            constraints_for_t(vec![], Variance::Invariant, None)
        );
        assert!(rendered.contains("TypeVarConstraints"));
        assert!(rendered.contains("name"));
    }

    #[test]
    fn test_type_bound_equality() {
        let clone_a = TypeBound::TraitBound("Clone".to_string());
        let clone_b = TypeBound::TraitBound("Clone".to_string());
        let debug = TypeBound::TraitBound("Debug".to_string());
        assert_eq!(clone_a, clone_b);
        assert_ne!(clone_a, debug);
    }

    #[test]
    fn test_type_bound_upper_bound() {
        assert!(matches!(
            TypeBound::UpperBound(Type::String),
            TypeBound::UpperBound(Type::String)
        ));
    }

    #[test]
    fn test_type_bound_union_bound() {
        let bound = TypeBound::UnionBound(vec![Type::Int, Type::String]);
        let TypeBound::UnionBound(members) = bound else {
            panic!("Expected UnionBound");
        };
        assert_eq!(members.len(), 2);
    }

    #[test]
    fn test_variance_equality() {
        for v in [Variance::Invariant, Variance::Covariant, Variance::Contravariant].iter() {
            assert_eq!(*v, *v);
        }
        assert_ne!(Variance::Invariant, Variance::Covariant);
    }

    #[test]
    fn test_variance_copy() {
        let first = Variance::Covariant;
        let second = first;
        assert_eq!(first, second);
    }

    #[test]
    fn test_type_parameter_clone() {
        let original = TypeParameter {
            name: "T".to_string(),
            bounds: vec!["Clone".to_string(), "Debug".to_string()],
            default: Some(Type::Int),
        };
        let copy = original.clone();
        assert_eq!(copy.name, "T");
        assert_eq!(copy.bounds.len(), 2);
    }

    #[test]
    fn test_generic_type_type_var() {
        let gt = GenericType::TypeVar("T".to_string());
        assert!(matches!(gt, GenericType::TypeVar(s) if s == "T"));
    }

    #[test]
    fn test_generic_type_generic() {
        let gt = GenericType::Generic {
            base: Type::List(Box::new(Type::Unknown)),
            params: vec![GenericType::TypeVar("T".to_string())],
        };
        if let GenericType::Generic { params, .. } = gt {
            assert_eq!(params.len(), 1);
        }
    }

    #[test]
    fn test_generic_type_union() {
        if let GenericType::Union(members) = GenericType::Union(vec![Type::Int, Type::String]) {
            assert_eq!(members.len(), 2);
        }
    }

    #[test]
    fn test_generic_type_concrete() {
        assert_eq!(
            GenericType::Concrete(Type::Int),
            GenericType::Concrete(Type::Int)
        );
    }

    #[test]
    fn test_generic_type_clone() {
        let original = GenericType::TypeVar("T".to_string());
        assert_eq!(original.clone(), original);
    }

    #[test]
    fn test_type_var_detection() {
        let mut collector = TypeVarCollector::new();
        for ty in [
            custom("T"),
            Type::List(Box::new(custom("U"))),
            custom("MyClass"),
        ]
        .iter()
        {
            collector.collect_from_type(ty);
        }
        assert!(collector.type_vars.contains("T"));
        assert!(collector.type_vars.contains("U"));
        assert!(!collector.type_vars.contains("MyClass"));
    }

    #[test]
    fn test_type_var_collector_type_var_type() {
        let mut collector = TypeVarCollector::new();
        collector.collect_from_type(&type_var("V"));
        assert!(collector.type_vars.contains("V"));
    }

    #[test]
    fn test_type_var_collector_dict_key() {
        let mut collector = TypeVarCollector::new();
        collector.collect_from_type(&Type::Dict(
            Box::new(type_var("K")),
            Box::new(type_var("V")),
        ));
        assert!(collector.type_vars.contains("K"));
        assert!(collector.type_vars.contains("V"));
        assert!(collector.dict_key_type_vars.contains("K"));
    }

    #[test]
    fn test_type_var_collector_function() {
        let mut collector = TypeVarCollector::new();
        collector.collect_from_type(&Type::Function {
            params: vec![custom("T")],
            ret: Box::new(custom("U")),
        });
        assert!(collector.type_vars.contains("T"));
        assert!(collector.type_vars.contains("U"));
    }

    #[test]
    fn test_type_var_collector_generic() {
        let mut collector = TypeVarCollector::new();
        collector.collect_from_type(&Type::Generic {
            base: "Either".to_string(),
            params: vec![custom("L"), custom("R")],
        });
        assert!(collector.type_vars.contains("L"));
        assert!(collector.type_vars.contains("R"));
    }

    #[test]
    fn test_type_var_collector_union() {
        let mut collector = TypeVarCollector::new();
        collector.collect_from_type(&Type::Union(vec![custom("A"), custom("B")]));
        assert!(collector.type_vars.contains("A"));
        assert!(collector.type_vars.contains("B"));
    }

    #[test]
    fn test_type_var_collector_set() {
        let mut collector = TypeVarCollector::new();
        collector.collect_from_type(&Type::Set(Box::new(custom("T"))));
        assert!(collector.type_vars.contains("T"));
    }

    #[test]
    fn test_generic_function_inference() {
        let identity = one_param_fn(
            "identity",
            HirParam::new("x".to_string(), custom("T")),
            custom("T"),
            vec![HirStmt::Return(Some(HirExpr::Var("x".to_string())))],
        );
        let mut registry = TypeVarRegistry::new();
        let type_params = registry.infer_function_generics(&identity).unwrap();
        assert_eq!(type_params.len(), 1);
        assert_eq!(type_params[0].name, "T");
    }

    #[test]
    fn test_infer_type_substitutions() {
        let get_first = one_param_fn(
            "get_first",
            HirParam::new("items".to_string(), Type::List(Box::new(Type::String))),
            Type::String,
            vec![],
        );
        let subs = TypeVarRegistry::new()
            .infer_type_substitutions(&get_first)
            .unwrap();
        assert!(subs.is_empty() || subs.contains_key("T"));
    }

    #[test]
    fn test_constraint_inference() {
        let mut inference = TypeInference::new();
        inference.add_method_constraint("T", "len");
        assert!(requires_trait(&inference, "T", "HasLen"));
    }

    #[test]
    fn test_type_inference_new() {
        let fresh = TypeInference::new();
        assert!(fresh.constraints.is_empty());
        assert!(fresh.substitutions.is_empty());
    }

    #[test]
    fn test_add_method_constraint_push() {
        let mut inference = TypeInference::new();
        inference.add_method_constraint("T", "push");
        assert!(requires_trait(&inference, "T", "VecLike"));
    }

    #[test]
    fn test_add_method_constraint_clone() {
        let mut inference = TypeInference::new();
        inference.add_method_constraint("T", "clone");
        assert!(requires_trait(&inference, "T", "Clone"));
    }

    #[test]
    fn test_add_method_constraint_unknown() {
        let mut inference = TypeInference::new();
        inference.add_method_constraint("T", "unknown_method");
        assert!(!inference.constraints.contains_key("T"));
    }
}