pytest_language_server/providers/
rename.rs1use super::Backend;
12use crate::fixtures::{decorators, FixtureDatabase};
13use rustpython_parser::ast::{
14 Arguments, Expr, ExprDictComp, ExprGeneratorExp, ExprLambda, ExprListComp, ExprName,
15 ExprSetComp, Ranged, Stmt, StmtAsyncFunctionDef, StmtFunctionDef, Visitor,
16};
17use rustpython_parser::text_size::TextRange;
18use rustpython_parser::{parse, Mode};
19use std::collections::{HashMap, HashSet};
20use tower_lsp_server::jsonrpc::{Error, Result};
21use tower_lsp_server::ls_types::*;
22use tracing::info;
23
/// Python 3 hard keywords (the contents of `keyword.kwlist`), kept in alphabetical order.
///
/// Used to reject rename targets that are keywords. Soft keywords such as `match`, `case`,
/// `type` and `_` are not listed: they remain usable as ordinary identifiers.
const PYTHON_KEYWORDS: &[&str] = &[
    "False", "None", "True",
    "and", "as", "assert", "async", "await",
    "break",
    "class", "continue",
    "def", "del",
    "elif", "else", "except",
    "finally", "for", "from",
    "global",
    "if", "import", "in", "is",
    "lambda",
    "nonlocal", "not",
    "or",
    "pass",
    "raise", "return",
    "try",
    "while", "with",
    "yield",
];
30
/// Result of resolving a rename request against a parametrized test function.
struct RenameTarget {
    /// LSP range of the occurrence under the cursor (what `prepareRename` reports).
    cursor_token: Range,
    /// LSP ranges of every occurrence to rewrite, sorted by position and de-duplicated.
    edits: Vec<Range>,
}
38
39struct FuncCtx<'a> {
41 decorators: &'a [Expr],
42 args: &'a Arguments,
43 body: &'a [Stmt],
44 range: TextRange,
45}
46
47impl FuncCtx<'_> {
48 fn bounds(&self) -> (usize, usize) {
51 let mut start = self.range.start().to_usize();
52 for dec in self.decorators {
53 start = start.min(dec.range().start().to_usize());
54 }
55 (start, self.range.end().to_usize())
56 }
57
58 fn contains(&self, offset: usize) -> bool {
59 let (start, end) = self.bounds();
60 start <= offset && offset <= end
61 }
62
63 fn span(&self) -> usize {
64 let (start, end) = self.bounds();
65 end - start
66 }
67}
68
/// AST visitor that records the range of every `Name` expression spelled like `target`.
///
/// It skips the bodies of nested functions/lambdas whose own parameters rebind the name,
/// and comprehension parts evaluated where a comprehension target rebinds it.
struct NameUsageCollector {
    /// Identifier being searched for.
    target: String,
    /// Ranges of matching `Name` nodes, in visit order.
    ranges: Vec<TextRange>,
}
84
85impl NameUsageCollector {
86 fn visit_arg_context(&mut self, args: &Arguments) {
88 for arg in args
89 .posonlyargs
90 .iter()
91 .chain(&args.args)
92 .chain(&args.kwonlyargs)
93 {
94 if let Some(default) = &arg.default {
95 self.visit_expr((**default).clone());
96 }
97 if let Some(annotation) = &arg.def.annotation {
98 self.visit_expr((**annotation).clone());
99 }
100 }
101 if let Some(va) = &args.vararg {
102 if let Some(annotation) = &va.annotation {
103 self.visit_expr((**annotation).clone());
104 }
105 }
106 if let Some(kw) = &args.kwarg {
107 if let Some(annotation) = &kw.annotation {
108 self.visit_expr((**annotation).clone());
109 }
110 }
111 }
112
113 fn visit_comprehension(
114 &mut self,
115 elements: Vec<Expr>,
116 generators: Vec<rustpython_parser::ast::Comprehension>,
117 ) {
118 let shadows = generators
119 .iter()
120 .any(|g| expr_binds_name(&g.target, &self.target));
121
122 for (i, generator) in generators.into_iter().enumerate() {
123 if i == 0 || !shadows {
125 self.visit_expr(generator.iter);
126 }
127 if !shadows {
128 for cond in generator.ifs {
129 self.visit_expr(cond);
130 }
131 }
132 }
133 if !shadows {
134 for element in elements {
135 self.visit_expr(element);
136 }
137 }
138 }
139}
140
141impl Visitor for NameUsageCollector {
142 fn visit_expr_name(&mut self, node: ExprName) {
143 if node.id.as_str() == self.target {
144 self.ranges.push(node.range);
145 }
146 }
147
148 fn visit_stmt_function_def(&mut self, node: StmtFunctionDef) {
149 for decorator in node.decorator_list {
150 self.visit_expr(decorator);
151 }
152 self.visit_arg_context(&node.args);
153 if let Some(returns) = node.returns {
154 self.visit_expr(*returns);
155 }
156 if !args_bind(&node.args, &self.target) {
157 for stmt in node.body {
158 self.visit_stmt(stmt);
159 }
160 }
161 }
162
163 fn visit_stmt_async_function_def(&mut self, node: StmtAsyncFunctionDef) {
164 for decorator in node.decorator_list {
165 self.visit_expr(decorator);
166 }
167 self.visit_arg_context(&node.args);
168 if let Some(returns) = node.returns {
169 self.visit_expr(*returns);
170 }
171 if !args_bind(&node.args, &self.target) {
172 for stmt in node.body {
173 self.visit_stmt(stmt);
174 }
175 }
176 }
177
178 fn visit_expr_lambda(&mut self, node: ExprLambda) {
179 self.visit_arg_context(&node.args);
180 if !args_bind(&node.args, &self.target) {
181 self.visit_expr(*node.body);
182 }
183 }
184
185 fn visit_expr_list_comp(&mut self, node: ExprListComp) {
186 self.visit_comprehension(vec![*node.elt], node.generators);
187 }
188
189 fn visit_expr_set_comp(&mut self, node: ExprSetComp) {
190 self.visit_comprehension(vec![*node.elt], node.generators);
191 }
192
193 fn visit_expr_generator_exp(&mut self, node: ExprGeneratorExp) {
194 self.visit_comprehension(vec![*node.elt], node.generators);
195 }
196
197 fn visit_expr_dict_comp(&mut self, node: ExprDictComp) {
198 self.visit_comprehension(vec![*node.key, *node.value], node.generators);
199 }
200}
201
202fn args_bind(args: &Arguments, target: &str) -> bool {
204 args.posonlyargs
205 .iter()
206 .chain(&args.args)
207 .chain(&args.kwonlyargs)
208 .any(|arg| arg.def.arg.as_str() == target)
209 || args
210 .vararg
211 .as_ref()
212 .is_some_and(|a| a.arg.as_str() == target)
213 || args
214 .kwarg
215 .as_ref()
216 .is_some_and(|a| a.arg.as_str() == target)
217}
218
219fn expr_binds_name(target: &Expr, name: &str) -> bool {
221 match target {
222 Expr::Name(n) => n.id.as_str() == name,
223 Expr::Tuple(t) => t.elts.iter().any(|e| expr_binds_name(e, name)),
224 Expr::List(l) => l.elts.iter().any(|e| expr_binds_name(e, name)),
225 Expr::Starred(s) => expr_binds_name(&s.value, name),
226 _ => false,
227 }
228}
229
230impl Backend {
231 pub async fn handle_prepare_rename(
233 &self,
234 params: TextDocumentPositionParams,
235 ) -> Result<Option<PrepareRenameResponse>> {
236 let uri = params.text_document.uri;
237 let position = params.position;
238
239 let Some(file_path) = self.uri_to_path(&uri) else {
240 return Ok(None);
241 };
242 let Some(content) = self.fixture_db.get_file_content(&file_path) else {
243 return Ok(None);
244 };
245
246 Ok(self
247 .parametrize_rename_target(&content, position)
248 .map(|target| PrepareRenameResponse::Range(target.cursor_token)))
249 }
250
251 pub async fn handle_rename(&self, params: RenameParams) -> Result<Option<WorkspaceEdit>> {
253 let uri = params.text_document_position.text_document.uri;
254 let position = params.text_document_position.position;
255 let new_name = params.new_name;
256
257 let Some(file_path) = self.uri_to_path(&uri) else {
258 return Ok(None);
259 };
260 let Some(content) = self.fixture_db.get_file_content(&file_path) else {
261 return Ok(None);
262 };
263
264 let Some(target) = self.parametrize_rename_target(&content, position) else {
265 return Ok(None);
266 };
267
268 if !is_valid_python_identifier(&new_name) {
269 return Err(Error::invalid_params(format!(
270 "'{new_name}' is not a valid Python identifier"
271 )));
272 }
273
274 info!(
275 "rename: {} occurrence(s) of parametrize param -> '{}'",
276 target.edits.len(),
277 new_name
278 );
279
280 let edits: Vec<TextEdit> = target
281 .edits
282 .into_iter()
283 .map(|range| TextEdit {
284 range,
285 new_text: new_name.clone(),
286 })
287 .collect();
288
289 let mut changes = HashMap::new();
290 changes.insert(uri, edits);
291
292 Ok(Some(WorkspaceEdit {
293 changes: Some(changes),
294 document_changes: None,
295 change_annotations: None,
296 }))
297 }
298
299 fn parametrize_rename_target(&self, content: &str, position: Position) -> Option<RenameTarget> {
301 let rustpython_parser::ast::Mod::Module(module) = parse(content, Mode::Module, "").ok()?
302 else {
303 return None;
304 };
305
306 let line_index = FixtureDatabase::build_line_index(content);
307 let cursor_offset = *line_index.get(position.line as usize)? + position.character as usize;
308
309 let mut functions = Vec::new();
313 collect_functions(&module.body, &mut functions);
314 let func = functions
315 .into_iter()
316 .filter(|f| f.contains(cursor_offset))
317 .filter(|f| {
318 f.decorators
319 .iter()
320 .any(|d| !decorators::extract_parametrize_argnames(d, content).is_empty())
321 })
322 .min_by_key(FuncCtx::span)?;
323
324 let mut name_to_decorator_ranges: HashMap<String, Vec<TextRange>> = HashMap::new();
327 for dec in func.decorators {
328 let argnames = decorators::extract_parametrize_argnames(dec, content);
329 let names: Vec<String> = argnames.iter().map(|(name, _)| name.clone()).collect();
330 let indirect = decorators::extract_parametrize_indirect_names(dec, &names);
331 for (name, range) in argnames {
332 if indirect.contains(&name) {
333 continue;
334 }
335 name_to_decorator_ranges
336 .entry(name)
337 .or_default()
338 .push(range);
339 }
340 }
341 if name_to_decorator_ranges.is_empty() {
342 return None;
343 }
344
345 let signature_params: HashSet<&str> = FixtureDatabase::all_args(func.args)
347 .map(|arg| arg.def.arg.as_str())
348 .collect();
349
350 let target_name = name_to_decorator_ranges
352 .iter()
353 .find(|(_, ranges)| ranges.iter().any(|r| range_contains(r, cursor_offset)))
354 .map(|(name, _)| name.clone())
355 .or_else(|| {
356 let word = identifier_at(content, cursor_offset)?;
357 (name_to_decorator_ranges.contains_key(&word)
358 && signature_params.contains(word.as_str()))
359 .then_some(word)
360 })?;
361
362 let mut occurrences: Vec<TextRange> = Vec::new();
364 occurrences.extend(
365 name_to_decorator_ranges
366 .remove(&target_name)
367 .into_iter()
368 .flatten(),
369 );
370
371 if let Some(arg) =
372 FixtureDatabase::all_args(func.args).find(|arg| arg.def.arg.as_str() == target_name)
373 {
374 let start = arg.def.range.start();
375 occurrences.push(TextRange::new(
376 start,
377 start + rustpython_parser::text_size::TextSize::from(target_name.len() as u32),
378 ));
379 }
380
381 let mut collector = NameUsageCollector {
382 target: target_name.clone(),
383 ranges: Vec::new(),
384 };
385 for stmt in func.body {
386 collector.visit_stmt(stmt.clone());
387 }
388 occurrences.extend(collector.ranges);
389
390 occurrences.sort_by_key(|r| (r.start().to_usize(), r.end().to_usize()));
391 occurrences.dedup();
392
393 let cursor_tr = occurrences
394 .iter()
395 .find(|r| range_contains(r, cursor_offset))
396 .copied()
397 .unwrap_or(occurrences[0]);
398
399 let to_lsp = |tr: &TextRange| self.text_range_to_lsp(tr, &line_index);
400 Some(RenameTarget {
401 cursor_token: to_lsp(&cursor_tr),
402 edits: occurrences.iter().map(to_lsp).collect(),
403 })
404 }
405
406 fn text_range_to_lsp(&self, tr: &TextRange, line_index: &[usize]) -> Range {
408 let start_offset = tr.start().to_usize();
409 let end_offset = tr.end().to_usize();
410 let start_line = self
411 .fixture_db
412 .get_line_from_offset(start_offset, line_index);
413 let end_line = self.fixture_db.get_line_from_offset(end_offset, line_index);
414 Range {
415 start: Position {
416 line: (start_line - 1) as u32,
417 character: self
418 .fixture_db
419 .get_char_position_from_offset(start_offset, line_index)
420 as u32,
421 },
422 end: Position {
423 line: (end_line - 1) as u32,
424 character: self
425 .fixture_db
426 .get_char_position_from_offset(end_offset, line_index)
427 as u32,
428 },
429 }
430 }
431}
432
433fn range_contains(range: &TextRange, offset: usize) -> bool {
434 range.start().to_usize() <= offset && offset <= range.end().to_usize()
435}
436
/// Return the ASCII identifier-like word (`[A-Za-z0-9_]+`) touching byte `offset`.
///
/// The word may extend on either side of `offset`. Returns `None` when `offset` is past the
/// end of `content` or no word byte is adjacent to it. Non-ASCII bytes always act as
/// separators.
fn identifier_at(content: &str, offset: usize) -> Option<String> {
    let raw = content.as_bytes();
    if offset > raw.len() {
        return None;
    }
    let is_word = |b: u8| b == b'_' || b.is_ascii_alphanumeric();

    // Word bytes immediately before and after the offset.
    let left = raw[..offset].iter().rev().take_while(|&&b| is_word(b)).count();
    let right = raw[offset..].iter().take_while(|&&b| is_word(b)).count();

    match left + right {
        0 => None,
        _ => Some(content[offset - left..offset + right].to_owned()),
    }
}
462
463fn collect_functions<'a>(stmts: &'a [Stmt], out: &mut Vec<FuncCtx<'a>>) {
465 for stmt in stmts {
466 match stmt {
467 Stmt::FunctionDef(f) => {
468 out.push(FuncCtx {
469 decorators: &f.decorator_list,
470 args: &f.args,
471 body: &f.body,
472 range: f.range,
473 });
474 collect_functions(&f.body, out);
475 }
476 Stmt::AsyncFunctionDef(f) => {
477 out.push(FuncCtx {
478 decorators: &f.decorator_list,
479 args: &f.args,
480 body: &f.body,
481 range: f.range,
482 });
483 collect_functions(&f.body, out);
484 }
485 Stmt::ClassDef(c) => collect_functions(&c.body, out),
486 _ => {}
487 }
488 }
489}
490
491fn is_valid_python_identifier(name: &str) -> bool {
492 let mut chars = name.chars();
493 match chars.next() {
494 Some(c) if c == '_' || c.is_ascii_alphabetic() => {}
495 _ => return false,
496 }
497 if !chars.all(|c| c == '_' || c.is_ascii_alphanumeric()) {
498 return false;
499 }
500 !PYTHON_KEYWORDS.contains(&name)
501}