Skip to main content

ruff_python_ast/
traversal.rs

1//! Utilities for manually traversing a Python AST.
2use crate::{self as ast, AnyNodeRef, ExceptHandler, Stmt};
3
4/// Given a [`Stmt`] and its parent, return the [`ast::Suite`] that contains the [`Stmt`].
5pub fn suite<'a>(
6    stmt: impl Into<AnyNodeRef<'a>>,
7    parent: impl Into<AnyNodeRef<'a>>,
8) -> Option<EnclosingSuite<'a>> {
9    // TODO: refactor this to work without a parent, ie when `stmt` is at the top level
10    let stmt = stmt.into();
11    match parent.into() {
12        AnyNodeRef::ModModule(ast::ModModule { body, .. }) => EnclosingSuite::new(body, stmt),
13        AnyNodeRef::StmtFunctionDef(ast::StmtFunctionDef { body, .. }) => {
14            EnclosingSuite::new(body, stmt)
15        }
16        AnyNodeRef::StmtClassDef(ast::StmtClassDef { body, .. }) => EnclosingSuite::new(body, stmt),
17        AnyNodeRef::StmtFor(ast::StmtFor { body, orelse, .. }) => [body, orelse]
18            .iter()
19            .find_map(|suite| EnclosingSuite::new(suite, stmt)),
20        AnyNodeRef::StmtWhile(ast::StmtWhile { body, orelse, .. }) => [body, orelse]
21            .iter()
22            .find_map(|suite| EnclosingSuite::new(suite, stmt)),
23        AnyNodeRef::StmtIf(ast::StmtIf {
24            body,
25            elif_else_clauses,
26            ..
27        }) => [body]
28            .into_iter()
29            .chain(elif_else_clauses.iter().map(|clause| &clause.body))
30            .find_map(|suite| EnclosingSuite::new(suite, stmt)),
31        AnyNodeRef::StmtWith(ast::StmtWith { body, .. }) => EnclosingSuite::new(body, stmt),
32        AnyNodeRef::StmtMatch(ast::StmtMatch { cases, .. }) => cases
33            .iter()
34            .map(|case| &case.body)
35            .find_map(|body| EnclosingSuite::new(body, stmt)),
36        AnyNodeRef::StmtTry(ast::StmtTry {
37            body,
38            handlers,
39            orelse,
40            finalbody,
41            ..
42        }) => [body, orelse, finalbody]
43            .into_iter()
44            .chain(
45                handlers
46                    .iter()
47                    .filter_map(ExceptHandler::as_except_handler)
48                    .map(|handler| &handler.body),
49            )
50            .find_map(|suite| EnclosingSuite::new(suite, stmt)),
51        _ => None,
52    }
53}
54
55pub struct EnclosingSuite<'a> {
56    suite: &'a [Stmt],
57    position: usize,
58}
59
60impl<'a> EnclosingSuite<'a> {
61    pub fn new(suite: &'a [Stmt], stmt: AnyNodeRef<'a>) -> Option<Self> {
62        let position = suite
63            .iter()
64            .position(|sibling| AnyNodeRef::ptr_eq(sibling.into(), stmt))?;
65
66        Some(EnclosingSuite { suite, position })
67    }
68
69    pub fn next_sibling(&self) -> Option<&'a Stmt> {
70        self.suite.get(self.position + 1)
71    }
72
73    pub fn next_siblings(&self) -> &'a [Stmt] {
74        self.suite.get(self.position + 1..).unwrap_or_default()
75    }
76
77    pub fn previous_sibling(&self) -> Option<&'a Stmt> {
78        self.suite.get(self.position.checked_sub(1)?)
79    }
80}
81
82impl std::ops::Deref for EnclosingSuite<'_> {
83    type Target = [Stmt];
84
85    fn deref(&self) -> &Self::Target {
86        self.suite
87    }
88}