ruff_python_ast/
traversal.rs1use crate::{self as ast, AnyNodeRef, ExceptHandler, Stmt};
3
4pub fn suite<'a>(
6 stmt: impl Into<AnyNodeRef<'a>>,
7 parent: impl Into<AnyNodeRef<'a>>,
8) -> Option<EnclosingSuite<'a>> {
9 let stmt = stmt.into();
11 match parent.into() {
12 AnyNodeRef::ModModule(ast::ModModule { body, .. }) => EnclosingSuite::new(body, stmt),
13 AnyNodeRef::StmtFunctionDef(ast::StmtFunctionDef { body, .. }) => {
14 EnclosingSuite::new(body, stmt)
15 }
16 AnyNodeRef::StmtClassDef(ast::StmtClassDef { body, .. }) => EnclosingSuite::new(body, stmt),
17 AnyNodeRef::StmtFor(ast::StmtFor { body, orelse, .. }) => [body, orelse]
18 .iter()
19 .find_map(|suite| EnclosingSuite::new(suite, stmt)),
20 AnyNodeRef::StmtWhile(ast::StmtWhile { body, orelse, .. }) => [body, orelse]
21 .iter()
22 .find_map(|suite| EnclosingSuite::new(suite, stmt)),
23 AnyNodeRef::StmtIf(ast::StmtIf {
24 body,
25 elif_else_clauses,
26 ..
27 }) => [body]
28 .into_iter()
29 .chain(elif_else_clauses.iter().map(|clause| &clause.body))
30 .find_map(|suite| EnclosingSuite::new(suite, stmt)),
31 AnyNodeRef::StmtWith(ast::StmtWith { body, .. }) => EnclosingSuite::new(body, stmt),
32 AnyNodeRef::StmtMatch(ast::StmtMatch { cases, .. }) => cases
33 .iter()
34 .map(|case| &case.body)
35 .find_map(|body| EnclosingSuite::new(body, stmt)),
36 AnyNodeRef::StmtTry(ast::StmtTry {
37 body,
38 handlers,
39 orelse,
40 finalbody,
41 ..
42 }) => [body, orelse, finalbody]
43 .into_iter()
44 .chain(
45 handlers
46 .iter()
47 .filter_map(ExceptHandler::as_except_handler)
48 .map(|handler| &handler.body),
49 )
50 .find_map(|suite| EnclosingSuite::new(suite, stmt)),
51 _ => None,
52 }
53}
54
55pub struct EnclosingSuite<'a> {
56 suite: &'a [Stmt],
57 position: usize,
58}
59
60impl<'a> EnclosingSuite<'a> {
61 pub fn new(suite: &'a [Stmt], stmt: AnyNodeRef<'a>) -> Option<Self> {
62 let position = suite
63 .iter()
64 .position(|sibling| AnyNodeRef::ptr_eq(sibling.into(), stmt))?;
65
66 Some(EnclosingSuite { suite, position })
67 }
68
69 pub fn next_sibling(&self) -> Option<&'a Stmt> {
70 self.suite.get(self.position + 1)
71 }
72
73 pub fn next_siblings(&self) -> &'a [Stmt] {
74 self.suite.get(self.position + 1..).unwrap_or_default()
75 }
76
77 pub fn previous_sibling(&self) -> Option<&'a Stmt> {
78 self.suite.get(self.position.checked_sub(1)?)
79 }
80}
81
82impl std::ops::Deref for EnclosingSuite<'_> {
83 type Target = [Stmt];
84
85 fn deref(&self) -> &Self::Target {
86 self.suite
87 }
88}