#[cfg(test)]
mod tests {
use crate::core::{ConstructType, ParameterInfo, ReplaceInfo};
use crate::migrate_ruff::migrate_file;
use crate::tests::test_utils::TestContext;
use crate::types::TypeIntrospectionMethod;
use std::collections::HashMap;
use std::path::Path;
fn migrate_source_with_replacements(
source: &str,
replacements: HashMap<String, ReplaceInfo>,
) -> String {
let test_ctx = TestContext::new(source);
let mut type_context = test_ctx.create_type_context(TypeIntrospectionMethod::PyrightLsp);
let result = migrate_file(
source,
"test_module",
Path::new(&test_ctx.file_path),
&mut type_context,
replacements,
HashMap::new(),
)
.unwrap();
drop(test_ctx);
result
}
fn create_test_with_base_classes(test_code: &str) -> String {
format!(
r#"
# Test base classes
class BaseRepo:
def do_commit(self, message, **kwargs):
pass
def stage(self, fs_paths):
pass
def get_worktree(self):
return WorkTree()
def reset_index(self, tree=None):
pass
def do_something(self, **kwargs):
pass
class WorkTree:
def stage(self, fs_paths):
pass
def unstage(self, fs_paths):
pass
def commit(self, message=None, **kwargs):
pass
def reset_index(self, tree=None):
pass
class Repo(BaseRepo):
def stage(self, fs_paths):
pass
@staticmethod
def init(path) -> 'Repo':
return Repo()
class Index:
def __init__(self, path):
pass
def get_entry(self, path):
return IndexEntry()
class IndexEntry:
def stage(self):
return 0
{}
"#,
test_code
)
}
fn create_replacement_info(
old_name: &str,
replacement_expr: &str,
parameters: Vec<&str>,
) -> ReplaceInfo {
let param_list = parameters.join(", ");
let mut python_return = replacement_expr.to_string();
python_return = python_return.replace("{**kwargs}", "**kwargs");
python_return = python_return.replace("{*args}", "*args");
for param in ¶meters {
let param_clean = param.trim_start_matches("**").trim_start_matches("*");
if !param.starts_with("**") && !param.starts_with("*") {
python_return = python_return.replace(&format!("{{{}}}", param), param_clean);
}
}
python_return = python_return
.replace("{self}", "self")
.replace("{", "")
.replace("}", "");
let func_name_only = old_name.split('.').next_back().unwrap_or(old_name);
let function_code = format!(
"def {}({}):\n return {}",
func_name_only, param_list, python_return
);
let replacement_ast = match rustpython_parser::parse(
&function_code,
rustpython_parser::Mode::Module,
"<test>",
) {
Ok(rustpython_ast::Mod::Module(module)) => {
if let Some(rustpython_ast::Stmt::FunctionDef(func)) = module.body.first() {
if let Some(rustpython_ast::Stmt::Return(ret)) = func.body.first() {
ret.value.clone()
} else {
None
}
} else {
None
}
}
_ => None,
};
ReplaceInfo {
old_name: old_name.to_string(),
replacement_expr: replacement_expr.to_string(),
replacement_ast,
construct_type: ConstructType::Function,
parameters: parameters
.iter()
.map(|&name| {
if let Some(stripped) = name.strip_prefix("**") {
ParameterInfo::kwarg(stripped)
} else if let Some(stripped) = name.strip_prefix("*") {
ParameterInfo::vararg(stripped)
} else {
ParameterInfo::new(name)
}
})
.collect(),
return_type: None,
since: None,
remove_in: None,
message: None,
}
}
#[test]
fn test_worktree_double_access_issue() {
let test_code = r#"
def test_worktree_operations():
# Create a WorkTree instance
worktree: WorkTree = WorkTree()
# This should NOT be migrated - worktree is already a WorkTree object
worktree.stage(["file.txt"])
worktree.unstage(["file.txt"])
"#;
let source = create_test_with_base_classes(test_code);
let mut replacements = HashMap::new();
replacements.insert(
"test_module.Repo.stage".to_string(),
create_replacement_info(
"stage",
"{self}.get_worktree().stage({fs_paths})",
vec!["self", "fs_paths"],
),
);
let test_ctx = TestContext::new(&source);
let mut type_context = test_ctx.create_type_context(TypeIntrospectionMethod::PyrightLsp);
let result = migrate_file(
&source,
"test_module",
Path::new(&test_ctx.file_path),
&mut type_context,
replacements,
HashMap::new(),
)
.unwrap();
drop(test_ctx);
assert!(result.contains("worktree.stage"));
assert!(result.contains("worktree.unstage"));
assert!(!result.contains("worktree.get_worktree().stage"));
assert!(!result.contains("worktree.get_worktree().unstage"));
}
#[test]
fn test_parameter_expansion_with_kwargs() {
let test_code = r#"
repo = BaseRepo()
repo.do_commit(
b"Initial commit",
committer=b"Test Committer <test@nodomain.com>",
author=b"Test Author <test@nodomain.com>",
commit_timestamp=12345,
commit_timezone=0,
author_timestamp=12345,
author_timezone=0,
)
"#;
let source = create_test_with_base_classes(test_code);
let mut replacements = HashMap::new();
replacements.insert(
"test_module.BaseRepo.do_commit".to_string(),
create_replacement_info(
"do_commit",
"{self}.get_worktree().commit(message={message}, {**kwargs})",
vec!["self", "message", "**kwargs"],
),
);
let result = migrate_source_with_replacements(&source, replacements);
if !result.contains("message=b\"Initial commit\"") {
eprintln!("Expected 'message=b\"Initial commit\"', got:\n{}", result);
}
assert!(result.contains("repo.get_worktree().commit("));
assert!(result.contains("message=b\"Initial commit\""));
assert!(result.contains("committer=b\"Test Committer <test@nodomain.com>\""));
let lines: Vec<&str> = result.lines().collect();
let commit_line = lines
.iter()
.find(|line| line.contains("repo.get_worktree().commit("))
.expect("Should find the migrated commit line");
assert!(
!commit_line.contains("tree="),
"The migrated commit call should not have tree= parameter"
);
}
#[test]
fn test_default_parameter_pollution() {
let test_code = r#"
repo = BaseRepo()
repo.do_commit(b"Simple commit")
"#;
let source = create_test_with_base_classes(test_code);
let mut replacements = HashMap::new();
let params = vec![
ParameterInfo {
name: "self".to_string(),
has_default: false,
default_value: None,
is_vararg: false,
is_kwarg: false,
is_kwonly: false,
},
ParameterInfo {
name: "message".to_string(),
has_default: true,
default_value: Some("None".to_string()),
is_vararg: false,
is_kwarg: false,
is_kwonly: false,
},
ParameterInfo {
name: "tree".to_string(),
has_default: true,
default_value: Some("None".to_string()),
is_vararg: false,
is_kwarg: false,
is_kwonly: false,
},
ParameterInfo {
name: "encoding".to_string(),
has_default: true,
default_value: Some("None".to_string()),
is_vararg: false,
is_kwarg: false,
is_kwonly: false,
},
];
replacements.insert(
"test_module.BaseRepo.do_commit".to_string(),
ReplaceInfo {
old_name: "do_commit".to_string(),
replacement_expr: "{self}.get_worktree().commit(message={message})".to_string(),
replacement_ast: None,
construct_type: ConstructType::Function,
parameters: params,
return_type: None,
since: None,
remove_in: None,
message: None,
},
);
let result = migrate_source_with_replacements(&source, replacements);
assert!(result.contains("repo.get_worktree().commit(message=b\"Simple commit\")"));
let commit_call = "repo.get_worktree().commit(message=b\"Simple commit\")";
assert!(result.contains(commit_call));
assert!(!result.contains("commit(message=b\"Simple commit\", tree="));
assert!(!result.contains("commit(message=b\"Simple commit\", encoding="));
}
#[test]
fn test_incomplete_migration_stage_and_commit() {
let test_code = r#"
# Inline the operations so pyright can track the type
r = Repo()
r.stage(["file.txt"])
r.do_commit("test commit")
"#;
let source = create_test_with_base_classes(test_code);
let mut replacements = HashMap::new();
replacements.insert(
"test_module.Repo.stage".to_string(),
create_replacement_info(
"stage",
"{self}.get_worktree().stage({fs_paths})",
vec!["self", "fs_paths"],
),
);
replacements.insert(
"test_module.BaseRepo.do_commit".to_string(),
create_replacement_info(
"do_commit",
"{self}.get_worktree().commit(message={message})",
vec!["self", "message"],
),
);
let result = migrate_source_with_replacements(&source, replacements);
if !result.contains("r.get_worktree().commit(message=\"test commit\")") {
eprintln!("Expected commit migration, got:");
for line in result.lines() {
if line.contains("commit") || line.contains("do_commit") {
eprintln!(" {}", line);
}
}
}
assert!(result.contains("r.get_worktree().stage([\"file.txt\"])"));
assert!(result.contains("r.get_worktree().commit(message=\"test commit\")"));
}
#[test]
fn test_worktree_stage_calls() {
let test_code = r#"
wt = WorkTree()
wt.stage(["file1.txt", "file2.txt"])
"#;
let source = create_test_with_base_classes(test_code);
let mut replacements = HashMap::new();
replacements.insert(
"test_module.Repo.stage".to_string(),
create_replacement_info(
"stage",
"{self}.get_worktree().stage({fs_paths})",
vec!["self", "fs_paths"],
),
);
let result = migrate_source_with_replacements(&source, replacements);
assert!(result.contains("wt.stage([\"file1.txt\", \"file2.txt\"])"));
assert!(!result.contains("wt.get_worktree()"));
}
#[test]
fn test_unprovided_parameter_placeholders() {
let test_code = r#"
repo = BaseRepo()
target = repo
target.reset_index()
"#;
let source = create_test_with_base_classes(test_code);
let mut replacements = HashMap::new();
let params = vec![
ParameterInfo {
name: "self".to_string(),
has_default: false,
default_value: None,
is_vararg: false,
is_kwarg: false,
is_kwonly: false,
},
ParameterInfo {
name: "tree".to_string(),
has_default: true,
default_value: Some("None".to_string()),
is_vararg: false,
is_kwarg: false,
is_kwonly: false,
},
];
let python_return = "{self}.get_worktree().reset_index({tree})"
.replace("{self}", "self")
.replace("{tree}", "tree");
let function_code = format!("def reset_index(self, tree):\n return {}", python_return);
let replacement_ast = match rustpython_parser::parse(
&function_code,
rustpython_parser::Mode::Module,
"<test>",
) {
Ok(rustpython_ast::Mod::Module(module)) => {
if let Some(rustpython_ast::Stmt::FunctionDef(func)) = module.body.first() {
if let Some(rustpython_ast::Stmt::Return(ret)) = func.body.first() {
ret.value.clone()
} else {
None
}
} else {
None
}
}
_ => None,
};
replacements.insert(
"test_module.BaseRepo.reset_index".to_string(),
ReplaceInfo {
old_name: "reset_index".to_string(),
replacement_expr: "{self}.get_worktree().reset_index({tree})".to_string(),
replacement_ast,
construct_type: ConstructType::Function,
parameters: params,
return_type: None,
since: None,
remove_in: None,
message: None,
},
);
let result = migrate_source_with_replacements(&source, replacements);
println!("Test source:\n{}", source);
println!("Migration result:\n{}", result);
assert!(result.contains("target.get_worktree().reset_index()"));
assert!(!result.contains("{tree}"));
}
#[test]
fn test_kwarg_pattern_detection() {
let test_code = r#"
def process(data, mode="fast"):
process_v2(data, mode)
"#;
let source = create_test_with_base_classes(test_code);
let mut replacements = HashMap::new();
replacements.insert(
"test_module.process_v2".to_string(),
create_replacement_info(
"process_v2",
"process_v2({data}, processing_mode={mode})",
vec!["data", "mode"],
),
);
let result = migrate_source_with_replacements(&source, replacements);
assert!(result.contains("process_v2(data, processing_mode=mode)"));
}
#[test]
fn test_kwargs_passthrough() {
let test_code = r#"
repo = BaseRepo()
repo.do_something(a=1, b=2, c=3)
"#;
let source = create_test_with_base_classes(test_code);
let mut replacements = HashMap::new();
replacements.insert(
"test_module.BaseRepo.do_something".to_string(),
create_replacement_info(
"do_something",
"{self}.new_method({**kwargs})",
vec!["self", "**kwargs"],
),
);
let result = migrate_source_with_replacements(&source, replacements);
assert!(result.contains("repo.new_method(a=1, b=2, c=3)"));
}
#[test]
fn test_kwargs_with_dict_expansion() {
let test_code = r#"
repo = BaseRepo()
commit_kwargs = {"author": "Test"}
repo.do_something(**commit_kwargs)
"#;
let source = create_test_with_base_classes(test_code);
let mut replacements = HashMap::new();
replacements.insert(
"test_module.BaseRepo.do_something".to_string(),
create_replacement_info(
"do_something",
"{self}.new_method({**kwargs})",
vec!["self", "**kwargs"],
),
);
let result = migrate_source_with_replacements(&source, replacements);
assert!(result.contains("repo.new_method(**commit_kwargs)"));
}
#[test]
fn test_dict_unpacking_without_kwarg_param() {
let test_code = r#"
def process_data(a, b):
return a + b
extra_args = {"b": 2}
result = process_data(1, **extra_args)
"#;
let source = create_test_with_base_classes(test_code);
let mut replacements = HashMap::new();
replacements.insert(
"test_module.process_data".to_string(),
create_replacement_info(
"test_module.process_data",
"new_process({a}, {b})",
vec!["a", "b"],
),
);
let result = migrate_source_with_replacements(&source, replacements);
if !result.contains("result = new_process(1, **extra_args)") {
eprintln!("Expected 'result = new_process(1, **extra_args)', got:");
for line in result.lines() {
if line.contains("new_process") {
eprintln!(" {}", line);
}
}
}
assert!(result.contains("result = new_process(1, **extra_args)"));
}
#[test]
fn test_dict_unpacking_no_extra_comma() {
let test_code = r#"
def func(**kwargs):
pass
d = {"key": "value"}
func(**d)
"#;
let source = create_test_with_base_classes(test_code);
let mut replacements = HashMap::new();
replacements.insert(
"test_module.func".to_string(),
create_replacement_info("func", "new_func({**kwargs})", vec!["**kwargs"]),
);
let result = migrate_source_with_replacements(&source, replacements);
assert!(result.contains("new_func(**d)"));
assert!(!result.contains("new_func(, **d)")); }
#[test]
fn test_method_call_on_variable_repo() {
let test_code = r#"
r = BaseRepo()
r.do_commit(b"Test commit", author=b"Test Author <test@example.com>")
"#;
let source = create_test_with_base_classes(test_code);
let mut replacements = HashMap::new();
replacements.insert(
"test_module.BaseRepo.do_commit".to_string(),
create_replacement_info(
"do_commit",
"{self}.get_worktree().commit(message={message}, {**kwargs})",
vec!["self", "message", "**kwargs"],
),
);
let result = migrate_source_with_replacements(&source, replacements);
assert!(result.contains("r.get_worktree().commit("));
assert!(
result.contains("message=b\"Test")
&& (result.contains("commit\"") || result.contains("commit\""))
);
assert!(result.contains("author=b\"Test") && result.contains("example.com"));
}
#[test]
fn test_import_replacement_function() {
let test_code = r#"
# Import at module level
from test_module import checkout_branch
def test_module_import():
# Module-qualified call should be replaced with FQN
test_module.checkout_branch(repo, "main")
def test_direct_call():
# Direct call without module prefix
checkout_branch(repo, "feature")
"#;
let source = create_test_with_base_classes(test_code);
let mut replacements = HashMap::new();
replacements.insert(
"test_module.checkout_branch".to_string(),
create_replacement_info(
"checkout_branch",
"test_module.checkout({repo}, {target})",
vec!["repo", "target"],
),
);
let result = migrate_source_with_replacements(&source, replacements);
println!("test_import_replacement_function debug:");
println!("{}", result);
assert!(result.contains("from test_module import checkout_branch"));
assert!(result.contains("test_module.checkout(repo, \"main\")"));
assert!(result.contains("test_module.checkout(repo, \"feature\")"));
}
#[test]
fn test_no_migration_without_type_info() {
let source = r#"
def test_unknown_type():
# entry type is unknown - we don't know if it's IndexEntry or something else
stage_num = entry.stage()
"#;
let mut replacements = HashMap::new();
replacements.insert(
"test_module.Repo.stage".to_string(),
create_replacement_info(
"stage",
"{self}.get_worktree().stage({fs_paths})",
vec!["self", "fs_paths"],
),
);
let result = migrate_source_with_replacements(source, replacements);
assert!(result.contains("entry.stage()"));
assert!(!result.contains("get_worktree()"));
}
#[test]
fn test_method_on_known_type() {
let test_code = r#"
def test_repo_stage():
repo = Repo.init(".")
repo.stage(["file.txt"])
"#;
let source = create_test_with_base_classes(test_code);
let mut replacements = HashMap::new();
replacements.insert(
"test_module.Repo.stage".to_string(),
create_replacement_info(
"stage",
"{self}.get_worktree().stage({fs_paths})",
vec!["self", "fs_paths"],
),
);
let result = migrate_source_with_replacements(&source, replacements);
println!("Test source:\n{}", source);
println!("\nMigration result:\n{}", result);
assert!(result.contains("repo.get_worktree().stage([\"file.txt\"])"));
}
}