Skip to main content

pytest_language_server/providers/
rename.rs

1//! Rename provider for `@pytest.mark.parametrize` parameters.
2//!
3//! Renaming a parametrized parameter rewrites, in one edit, the name token inside the
4//! `@pytest.mark.parametrize(...)` decorator string, the matching function-signature parameter,
5//! and every usage of that parameter in the function body.  The rename can be triggered from any
6//! of those three sites.
7//!
8//! Only parametrize parameters are handled; for any other symbol the request returns `None` so a
9//! general Python language server can answer it.
10
11use super::Backend;
12use crate::fixtures::{decorators, FixtureDatabase};
13use rustpython_parser::ast::{
14    Arguments, Expr, ExprDictComp, ExprGeneratorExp, ExprLambda, ExprListComp, ExprName,
15    ExprSetComp, Ranged, Stmt, StmtAsyncFunctionDef, StmtFunctionDef, Visitor,
16};
17use rustpython_parser::text_size::TextRange;
18use rustpython_parser::{parse, Mode};
19use std::collections::{HashMap, HashSet};
20use tower_lsp_server::jsonrpc::{Error, Result};
21use tower_lsp_server::ls_types::*;
22use tracing::info;
23
/// Python's reserved (hard) keywords.
///
/// Consulted by `is_valid_python_identifier` so that a rename to e.g. `class` or `None` is
/// rejected. Soft keywords such as `match` or `case` are not listed here; they remain legal
/// identifiers in Python.
const PYTHON_KEYWORDS: &[&str] = &[
    "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class", "continue",
    "def", "del", "elif", "else", "except", "finally", "for", "from", "global", "if", "import",
    "in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise", "return", "try", "while",
    "with", "yield",
];
30
/// All occurrences of a parametrize parameter within one test function that must change together.
///
/// Both fields are already expressed as LSP ranges (line/character), so they can be placed
/// directly into a `PrepareRenameResponse` or a `TextEdit` without further conversion.
struct RenameTarget {
    /// LSP range of the token under the cursor (for the prepareRename response).
    cursor_token: Range,
    /// Every editable occurrence: decorator name token(s), signature parameter, body usages.
    /// Each range is replaced with the new name when the rename is applied.
    edits: Vec<Range>,
}
38
/// A function definition with the parts needed for parametrize rename, borrowed from the AST.
///
/// The same shape is used for both `def` and `async def`, so later code does not need to
/// distinguish the two statement kinds.
struct FuncCtx<'a> {
    /// The definition's decorator expressions (its `decorator_list`).
    decorators: &'a [Expr],
    /// The definition's parameter list.
    args: &'a Arguments,
    /// The statements making up the function body.
    body: &'a [Stmt],
    /// Source range of the definition node itself.
    range: TextRange,
}
46
47impl FuncCtx<'_> {
48    /// Source span covering the decorators and the `def` body, used to locate the cursor.
49    /// (`FunctionDef.range` starts at `def`, so decorators must be folded in explicitly.)
50    fn bounds(&self) -> (usize, usize) {
51        let mut start = self.range.start().to_usize();
52        for dec in self.decorators {
53            start = start.min(dec.range().start().to_usize());
54        }
55        (start, self.range.end().to_usize())
56    }
57
58    fn contains(&self, offset: usize) -> bool {
59        let (start, end) = self.bounds();
60        start <= offset && offset <= end
61    }
62
63    fn span(&self) -> usize {
64        let (start, end) = self.bounds();
65        end - start
66    }
67}
68
/// Collects the ranges of every `Name` expression that refers to a target parameter, walking the
/// function body via the generated `Visitor`.
///
/// It is scope-aware: a nested function/lambda whose parameters shadow the target, and a
/// comprehension whose loop target shadows it, bind a *different* variable, so their inner bodies
/// are not collected. Parts evaluated in the enclosing scope (decorators, parameter defaults and
/// annotations, the first comprehension iterable) are still visited.
///
/// Limitation: a nested function that rebinds the name by assignment, `global`, or `nonlocal`
/// (rather than by parameter) is not detected; that case over-collects. It does not occur in
/// practice in test bodies.
struct NameUsageCollector {
    /// Identifier being searched for; compared against each visited `Name` node's `id`.
    target: String,
    /// Source ranges of the matching `Name` nodes, in the order they were visited.
    ranges: Vec<TextRange>,
}
84
85impl NameUsageCollector {
86    /// Visit parameter defaults and annotations, which are evaluated in the enclosing scope.
87    fn visit_arg_context(&mut self, args: &Arguments) {
88        for arg in args
89            .posonlyargs
90            .iter()
91            .chain(&args.args)
92            .chain(&args.kwonlyargs)
93        {
94            if let Some(default) = &arg.default {
95                self.visit_expr((**default).clone());
96            }
97            if let Some(annotation) = &arg.def.annotation {
98                self.visit_expr((**annotation).clone());
99            }
100        }
101        if let Some(va) = &args.vararg {
102            if let Some(annotation) = &va.annotation {
103                self.visit_expr((**annotation).clone());
104            }
105        }
106        if let Some(kw) = &args.kwarg {
107            if let Some(annotation) = &kw.annotation {
108                self.visit_expr((**annotation).clone());
109            }
110        }
111    }
112
113    fn visit_comprehension(
114        &mut self,
115        elements: Vec<Expr>,
116        generators: Vec<rustpython_parser::ast::Comprehension>,
117    ) {
118        let shadows = generators
119            .iter()
120            .any(|g| expr_binds_name(&g.target, &self.target));
121
122        for (i, generator) in generators.into_iter().enumerate() {
123            // The first generator's iterable is evaluated in the enclosing scope.
124            if i == 0 || !shadows {
125                self.visit_expr(generator.iter);
126            }
127            if !shadows {
128                for cond in generator.ifs {
129                    self.visit_expr(cond);
130                }
131            }
132        }
133        if !shadows {
134            for element in elements {
135                self.visit_expr(element);
136            }
137        }
138    }
139}
140
141impl Visitor for NameUsageCollector {
142    fn visit_expr_name(&mut self, node: ExprName) {
143        if node.id.as_str() == self.target {
144            self.ranges.push(node.range);
145        }
146    }
147
148    fn visit_stmt_function_def(&mut self, node: StmtFunctionDef) {
149        for decorator in node.decorator_list {
150            self.visit_expr(decorator);
151        }
152        self.visit_arg_context(&node.args);
153        if let Some(returns) = node.returns {
154            self.visit_expr(*returns);
155        }
156        if !args_bind(&node.args, &self.target) {
157            for stmt in node.body {
158                self.visit_stmt(stmt);
159            }
160        }
161    }
162
163    fn visit_stmt_async_function_def(&mut self, node: StmtAsyncFunctionDef) {
164        for decorator in node.decorator_list {
165            self.visit_expr(decorator);
166        }
167        self.visit_arg_context(&node.args);
168        if let Some(returns) = node.returns {
169            self.visit_expr(*returns);
170        }
171        if !args_bind(&node.args, &self.target) {
172            for stmt in node.body {
173                self.visit_stmt(stmt);
174            }
175        }
176    }
177
178    fn visit_expr_lambda(&mut self, node: ExprLambda) {
179        self.visit_arg_context(&node.args);
180        if !args_bind(&node.args, &self.target) {
181            self.visit_expr(*node.body);
182        }
183    }
184
185    fn visit_expr_list_comp(&mut self, node: ExprListComp) {
186        self.visit_comprehension(vec![*node.elt], node.generators);
187    }
188
189    fn visit_expr_set_comp(&mut self, node: ExprSetComp) {
190        self.visit_comprehension(vec![*node.elt], node.generators);
191    }
192
193    fn visit_expr_generator_exp(&mut self, node: ExprGeneratorExp) {
194        self.visit_comprehension(vec![*node.elt], node.generators);
195    }
196
197    fn visit_expr_dict_comp(&mut self, node: ExprDictComp) {
198        self.visit_comprehension(vec![*node.key, *node.value], node.generators);
199    }
200}
201
202/// Whether any parameter of `args` is named `target`.
203fn args_bind(args: &Arguments, target: &str) -> bool {
204    args.posonlyargs
205        .iter()
206        .chain(&args.args)
207        .chain(&args.kwonlyargs)
208        .any(|arg| arg.def.arg.as_str() == target)
209        || args
210            .vararg
211            .as_ref()
212            .is_some_and(|a| a.arg.as_str() == target)
213        || args
214            .kwarg
215            .as_ref()
216            .is_some_and(|a| a.arg.as_str() == target)
217}
218
219/// Whether an assignment/comprehension target binds `name` (handles tuple/list/star unpacking).
220fn expr_binds_name(target: &Expr, name: &str) -> bool {
221    match target {
222        Expr::Name(n) => n.id.as_str() == name,
223        Expr::Tuple(t) => t.elts.iter().any(|e| expr_binds_name(e, name)),
224        Expr::List(l) => l.elts.iter().any(|e| expr_binds_name(e, name)),
225        Expr::Starred(s) => expr_binds_name(&s.value, name),
226        _ => false,
227    }
228}
229
230impl Backend {
231    /// Handle a `textDocument/prepareRename` request.
232    pub async fn handle_prepare_rename(
233        &self,
234        params: TextDocumentPositionParams,
235    ) -> Result<Option<PrepareRenameResponse>> {
236        let uri = params.text_document.uri;
237        let position = params.position;
238
239        let Some(file_path) = self.uri_to_path(&uri) else {
240            return Ok(None);
241        };
242        let Some(content) = self.fixture_db.get_file_content(&file_path) else {
243            return Ok(None);
244        };
245
246        Ok(self
247            .parametrize_rename_target(&content, position)
248            .map(|target| PrepareRenameResponse::Range(target.cursor_token)))
249    }
250
251    /// Handle a `textDocument/rename` request.
252    pub async fn handle_rename(&self, params: RenameParams) -> Result<Option<WorkspaceEdit>> {
253        let uri = params.text_document_position.text_document.uri;
254        let position = params.text_document_position.position;
255        let new_name = params.new_name;
256
257        let Some(file_path) = self.uri_to_path(&uri) else {
258            return Ok(None);
259        };
260        let Some(content) = self.fixture_db.get_file_content(&file_path) else {
261            return Ok(None);
262        };
263
264        let Some(target) = self.parametrize_rename_target(&content, position) else {
265            return Ok(None);
266        };
267
268        if !is_valid_python_identifier(&new_name) {
269            return Err(Error::invalid_params(format!(
270                "'{new_name}' is not a valid Python identifier"
271            )));
272        }
273
274        info!(
275            "rename: {} occurrence(s) of parametrize param -> '{}'",
276            target.edits.len(),
277            new_name
278        );
279
280        let edits: Vec<TextEdit> = target
281            .edits
282            .into_iter()
283            .map(|range| TextEdit {
284                range,
285                new_text: new_name.clone(),
286            })
287            .collect();
288
289        let mut changes = HashMap::new();
290        changes.insert(uri, edits);
291
292        Ok(Some(WorkspaceEdit {
293            changes: Some(changes),
294            document_changes: None,
295            change_annotations: None,
296        }))
297    }
298
299    /// Resolve the parametrize parameter at `position` and gather all of its occurrences.
300    fn parametrize_rename_target(&self, content: &str, position: Position) -> Option<RenameTarget> {
301        let rustpython_parser::ast::Mod::Module(module) = parse(content, Mode::Module, "").ok()?
302        else {
303            return None;
304        };
305
306        let line_index = FixtureDatabase::build_line_index(content);
307        let cursor_offset = *line_index.get(position.line as usize)? + position.character as usize;
308
309        // Innermost *parametrized* function whose decorators or body contain the cursor. Filtering
310        // to parametrized functions means a cursor inside a nested closure that references the
311        // parameter still resolves to the enclosing parametrized test rather than the closure.
312        let mut functions = Vec::new();
313        collect_functions(&module.body, &mut functions);
314        let func = functions
315            .into_iter()
316            .filter(|f| f.contains(cursor_offset))
317            .filter(|f| {
318                f.decorators
319                    .iter()
320                    .any(|d| !decorators::extract_parametrize_argnames(d, content).is_empty())
321            })
322            .min_by_key(FuncCtx::span)?;
323
324        // Parametrize names declared across all decorators, excluding indirect ones (those route
325        // to a fixture, so a local-only rename would silently break the test).
326        let mut name_to_decorator_ranges: HashMap<String, Vec<TextRange>> = HashMap::new();
327        for dec in func.decorators {
328            let argnames = decorators::extract_parametrize_argnames(dec, content);
329            let names: Vec<String> = argnames.iter().map(|(name, _)| name.clone()).collect();
330            let indirect = decorators::extract_parametrize_indirect_names(dec, &names);
331            for (name, range) in argnames {
332                if indirect.contains(&name) {
333                    continue;
334                }
335                name_to_decorator_ranges
336                    .entry(name)
337                    .or_default()
338                    .push(range);
339            }
340        }
341        if name_to_decorator_ranges.is_empty() {
342            return None;
343        }
344
345        // Signature parameter names, used to confirm the cursor sits on a real parameter.
346        let signature_params: HashSet<&str> = FixtureDatabase::all_args(func.args)
347            .map(|arg| arg.def.arg.as_str())
348            .collect();
349
350        // Determine the target name from whichever site the cursor is on.
351        let target_name = name_to_decorator_ranges
352            .iter()
353            .find(|(_, ranges)| ranges.iter().any(|r| range_contains(r, cursor_offset)))
354            .map(|(name, _)| name.clone())
355            .or_else(|| {
356                let word = identifier_at(content, cursor_offset)?;
357                (name_to_decorator_ranges.contains_key(&word)
358                    && signature_params.contains(word.as_str()))
359                .then_some(word)
360            })?;
361
362        // Gather every occurrence to edit.
363        let mut occurrences: Vec<TextRange> = Vec::new();
364        occurrences.extend(
365            name_to_decorator_ranges
366                .remove(&target_name)
367                .into_iter()
368                .flatten(),
369        );
370
371        if let Some(arg) =
372            FixtureDatabase::all_args(func.args).find(|arg| arg.def.arg.as_str() == target_name)
373        {
374            let start = arg.def.range.start();
375            occurrences.push(TextRange::new(
376                start,
377                start + rustpython_parser::text_size::TextSize::from(target_name.len() as u32),
378            ));
379        }
380
381        let mut collector = NameUsageCollector {
382            target: target_name.clone(),
383            ranges: Vec::new(),
384        };
385        for stmt in func.body {
386            collector.visit_stmt(stmt.clone());
387        }
388        occurrences.extend(collector.ranges);
389
390        occurrences.sort_by_key(|r| (r.start().to_usize(), r.end().to_usize()));
391        occurrences.dedup();
392
393        let cursor_tr = occurrences
394            .iter()
395            .find(|r| range_contains(r, cursor_offset))
396            .copied()
397            .unwrap_or(occurrences[0]);
398
399        let to_lsp = |tr: &TextRange| self.text_range_to_lsp(tr, &line_index);
400        Some(RenameTarget {
401            cursor_token: to_lsp(&cursor_tr),
402            edits: occurrences.iter().map(to_lsp).collect(),
403        })
404    }
405
406    /// Convert a source [`TextRange`] into an LSP [`Range`] using the file's line index.
407    fn text_range_to_lsp(&self, tr: &TextRange, line_index: &[usize]) -> Range {
408        let start_offset = tr.start().to_usize();
409        let end_offset = tr.end().to_usize();
410        let start_line = self
411            .fixture_db
412            .get_line_from_offset(start_offset, line_index);
413        let end_line = self.fixture_db.get_line_from_offset(end_offset, line_index);
414        Range {
415            start: Position {
416                line: (start_line - 1) as u32,
417                character: self
418                    .fixture_db
419                    .get_char_position_from_offset(start_offset, line_index)
420                    as u32,
421            },
422            end: Position {
423                line: (end_line - 1) as u32,
424                character: self
425                    .fixture_db
426                    .get_char_position_from_offset(end_offset, line_index)
427                    as u32,
428            },
429        }
430    }
431}
432
433fn range_contains(range: &TextRange, offset: usize) -> bool {
434    range.start().to_usize() <= offset && offset <= range.end().to_usize()
435}
436
/// Extracts the ASCII identifier (`[A-Za-z0-9_]+`) that touches byte `offset` in `content`.
///
/// The word is grown in both directions from `offset`, so a caret sitting directly after the last
/// character of a word still resolves to that word. Returns `None` when `offset` is beyond the
/// end of `content` or when no word character is adjacent to it.
///
/// Only single-byte ASCII characters are ever included, so the resulting slice boundaries always
/// fall on character boundaries.
fn identifier_at(content: &str, offset: usize) -> Option<String> {
    let bytes = content.as_bytes();
    // `get` yields `None` for an out-of-range offset, which doubles as the bounds check.
    let before = bytes.get(..offset)?;
    let after = bytes.get(offset..)?;

    let is_word = |b: u8| b == b'_' || b.is_ascii_alphanumeric();
    let leading = before.iter().rev().take_while(|&&b| is_word(b)).count();
    let trailing = after.iter().take_while(|&&b| is_word(b)).count();

    if leading + trailing == 0 {
        return None;
    }
    Some(content[offset - leading..offset + trailing].to_string())
}
462
463/// Recursively collect every function definition, descending into classes and nested functions.
464fn collect_functions<'a>(stmts: &'a [Stmt], out: &mut Vec<FuncCtx<'a>>) {
465    for stmt in stmts {
466        match stmt {
467            Stmt::FunctionDef(f) => {
468                out.push(FuncCtx {
469                    decorators: &f.decorator_list,
470                    args: &f.args,
471                    body: &f.body,
472                    range: f.range,
473                });
474                collect_functions(&f.body, out);
475            }
476            Stmt::AsyncFunctionDef(f) => {
477                out.push(FuncCtx {
478                    decorators: &f.decorator_list,
479                    args: &f.args,
480                    body: &f.body,
481                    range: f.range,
482                });
483                collect_functions(&f.body, out);
484            }
485            Stmt::ClassDef(c) => collect_functions(&c.body, out),
486            _ => {}
487        }
488    }
489}
490
491fn is_valid_python_identifier(name: &str) -> bool {
492    let mut chars = name.chars();
493    match chars.next() {
494        Some(c) if c == '_' || c.is_ascii_alphabetic() => {}
495        _ => return false,
496    }
497    if !chars.all(|c| c == '_' || c.is_ascii_alphanumeric()) {
498        return false;
499    }
500    !PYTHON_KEYWORDS.contains(&name)
501}