use std::fs;
use dup_defs_core::{Language, LineMap, ModuleDef};
use rayon::prelude::*;
use ruff_python_ast::{Expr, Parameters, Stmt};
use ruff_python_parser::parse_module;
/// Counts the lines of `text` that contain anything other than whitespace.
fn count_loc(text: &str) -> usize {
    text.lines()
        .fold(0, |total, line| if line.trim().is_empty() { total } else { total + 1 })
}
/// Total number of parameters a Python signature declares.
///
/// Counts positional-only, regular and keyword-only parameters, plus one each
/// for a `*args` and a `**kwargs` catch-all when present.
fn count_args(p: &Parameters) -> usize {
    let named = p.posonlyargs.len() + p.args.len() + p.kwonlyargs.len();
    let variadic = [p.vararg.is_some(), p.kwarg.is_some()]
        .iter()
        .filter(|&&present| present)
        .count();
    named + variadic
}
/// Reports whether `name` follows the SCREAMING_SNAKE_CASE constant convention.
///
/// The name may contain only ASCII uppercase letters, digits and underscores,
/// and must include at least one uppercase letter. Requiring a letter keeps
/// placeholder names such as `_` or `__` (and digit/underscore-only strings)
/// from being classified as constants. Any non-ASCII character disqualifies
/// the name.
fn is_upper(name: &str) -> bool {
    name.bytes().any(|b| b.is_ascii_uppercase())
        && name
            .bytes()
            .all(|b| b.is_ascii_uppercase() || b.is_ascii_digit() || b == b'_')
}
fn const_name(stmt: &Stmt) -> Option<String> {
match stmt {
Stmt::Assign(node) => match node.targets.as_slice() {
[Expr::Name(name)] if is_upper(name.id.as_str()) => Some(name.id.as_str().to_owned()),
_ => None,
},
Stmt::AnnAssign(node) => match node.target.as_ref() {
Expr::Name(name) if is_upper(name.id.as_str()) => Some(name.id.as_str().to_owned()),
_ => None,
},
_ => None,
}
}
/// Byte offset of the `def`/`async`/`class` keyword of a possibly-decorated definition.
///
/// Without decorators, `range_start` is returned unchanged. Otherwise scanning
/// begins at `last_decorator_end` and skips ASCII whitespace and `#` comments
/// (each comment up to its newline). The first other byte's offset is returned,
/// or `source.len()` if the input runs out.
fn keyword_start(source: &str, range_start: usize, last_decorator_end: Option<usize>) -> usize {
    let Some(mut pos) = last_decorator_end else { return range_start };
    let bytes = source.as_bytes();
    while let Some(&byte) = bytes.get(pos) {
        if byte.is_ascii_whitespace() {
            pos += 1;
        } else if byte == b'#' {
            // Jump to the comment's terminating newline; that newline is then
            // consumed as whitespace on the next iteration.
            pos = bytes[pos..]
                .iter()
                .position(|&c| c == b'\n')
                .map_or(bytes.len(), |offset| pos + offset);
        } else {
            break;
        }
    }
    pos
}
fn is_trivial_function_body(body: &[ruff_python_ast::Stmt]) -> bool {
body.iter().all(|s| match s {
Stmt::Pass(_) => true,
Stmt::Expr(e) => matches!(e.value.as_ref(), Expr::EllipsisLiteral(_) | Expr::StringLiteral(_)),
Stmt::Raise(r) => match r.exc.as_deref() {
Some(Expr::Name(name)) => name.id.as_str() == "NotImplementedError",
Some(Expr::Call(call)) => matches!(call.func.as_ref(), Expr::Name(n) if n.id.as_str() == "NotImplementedError"),
_ => false,
},
Stmt::Return(r) => match r.value.as_deref() {
None => true,
Some(e) => matches!(
e,
Expr::NoneLiteral(_)
| Expr::BooleanLiteral(_)
| Expr::NumberLiteral(_)
| Expr::StringLiteral(_)
| Expr::BytesLiteral(_)
| Expr::EllipsisLiteral(_)
| Expr::Name(_)
),
},
_ => false,
})
}
fn method_receiver_strip_range(source: &str, params: &Parameters) -> Option<(usize, usize)> {
if !params.posonlyargs.is_empty() && params.posonlyargs.len() == 1 && params.args.is_empty() {
let only = ¶ms.posonlyargs[0];
let n = only.parameter.name.id.as_str();
if n == "self" || n == "cls" {
return None;
}
}
let first = params.posonlyargs.first().or_else(|| params.args.first())?;
let name = first.parameter.name.id.as_str();
if name != "self" && name != "cls" {
return None;
}
let param_start = usize::from(first.parameter.range.start());
let after_param = usize::from(first.range.end());
let bytes = source.as_bytes();
let mut i = after_param;
while i < bytes.len() && bytes[i].is_ascii_whitespace() {
i += 1;
}
if i < bytes.len() && bytes[i] == b',' {
i += 1;
while i < bytes.len() && bytes[i].is_ascii_whitespace() {
i += 1;
}
Some((param_start, i))
} else {
Some((param_start, after_param))
}
}
/// Accessor role named by the first `<expr>.setter`, `<expr>.deleter` or
/// `<expr>.getter` decorator in `decorators`, if any.
///
/// Only attribute-access decorators are considered. For example, `@value.setter`
/// yields `Some("setter")`, while `@property` yields nothing.
fn property_accessor_suffix(decorators: &[ruff_python_ast::Decorator]) -> Option<&'static str> {
    const ROLES: [&str; 3] = ["setter", "deleter", "getter"];
    decorators.iter().find_map(|decorator| match &decorator.expression {
        Expr::Attribute(attribute) => {
            let attr_name = attribute.attr.id.as_str();
            ROLES.iter().copied().find(|&role| role == attr_name)
        }
        _ => None,
    })
}
/// Classifies one top-level statement as a reportable definition.
///
/// Returns `(kind, name, start, end, args)`:
/// - `kind` is `"functions"`, `"classes"`, `"type-aliases"` or `"constants"`;
/// - `start`/`end` are byte offsets into `source`. For functions and classes,
///   `start` is the keyword after any decorators;
/// - `args` is the parameter count for functions and `0` otherwise.
///
/// Returns `None` for:
/// - functions whose body is a placeholder (see [`is_trivial_function_body`]);
/// - type aliases whose name is not a plain identifier;
/// - assignments that are not upper-case constants;
/// - every other statement kind.
pub(crate) fn classify(source: &str, stmt: &Stmt) -> Option<(&'static str, String, usize, usize, usize)> {
    let classified = match stmt {
        Stmt::FunctionDef(func) => {
            if is_trivial_function_body(&func.body) {
                return None;
            }
            let last_decorator_end = func.decorator_list.last().map(|d| usize::from(d.range.end()));
            (
                "functions",
                func.name.id.as_str().to_owned(),
                keyword_start(source, usize::from(func.range.start()), last_decorator_end),
                usize::from(func.range.end()),
                count_args(&func.parameters),
            )
        }
        Stmt::ClassDef(class) => {
            let last_decorator_end = class.decorator_list.last().map(|d| usize::from(d.range.end()));
            (
                "classes",
                class.name.id.as_str().to_owned(),
                keyword_start(source, usize::from(class.range.start()), last_decorator_end),
                usize::from(class.range.end()),
                0,
            )
        }
        Stmt::TypeAlias(alias) => {
            let Expr::Name(name) = alias.name.as_ref() else { return None };
            (
                "type-aliases",
                name.id.as_str().to_owned(),
                usize::from(alias.range.start()),
                usize::from(alias.range.end()),
                0,
            )
        }
        Stmt::Assign(assign) => (
            "constants",
            const_name(stmt)?,
            usize::from(assign.range.start()),
            usize::from(assign.range.end()),
            0,
        ),
        Stmt::AnnAssign(ann) => (
            "constants",
            const_name(stmt)?,
            usize::from(ann.range.start()),
            usize::from(ann.range.end()),
            0,
        ),
        _ => return None,
    };
    Some(classified)
}
/// Collects a `"methods"` entry for every non-trivial method defined directly
/// in the class `stmt`, recursing into classes nested in its body.
///
/// Naming:
/// - method names are qualified by the dotted chain of enclosing classes
///   (`Outer.Inner.method`);
/// - property accessors decorated with `.setter`/`.deleter`/`.getter` get the
///   role appended (`C.value.setter`).
///
/// Text and metrics:
/// - `text` has the leading `self`/`cls` receiver cut out;
/// - `text_orig`, `loc` and `args` describe the unmodified source.
///
/// Returns an empty vector when `stmt` is not a class definition.
fn class_method_defs(source: &str, stmt: &Stmt, lines: &LineMap, file: &str, parent_chain: &str) -> Vec<ModuleDef> {
    let Stmt::ClassDef(class) = stmt else { return Vec::new() };
    let class_name = class.name.id.as_str();
    let qualified = match parent_chain {
        "" => class_name.to_owned(),
        chain => format!("{chain}.{class_name}"),
    };
    let mut methods = Vec::new();
    for member in &class.body {
        match member {
            // Nested classes contribute their methods under the extended chain.
            Stmt::ClassDef(_) => {
                methods.extend(class_method_defs(source, member, lines, file, &qualified));
            }
            Stmt::FunctionDef(func) if !is_trivial_function_body(&func.body) => {
                let last_decorator_end = func.decorator_list.last().map(|d| usize::from(d.range.end()));
                let start = keyword_start(source, usize::from(func.range.start()), last_decorator_end);
                let end = usize::from(func.range.end());
                let text_orig = source[start..end].to_owned();
                // Only cut the receiver when its range lies inside this method's text.
                let text = match method_receiver_strip_range(source, &func.parameters) {
                    Some((cut_from, cut_to)) if cut_from >= start && cut_to <= end => {
                        [&source[start..cut_from], &source[cut_to..end]].concat()
                    }
                    _ => text_orig.clone(),
                };
                let method_name = func.name.id.as_str();
                let name = property_accessor_suffix(&func.decorator_list).map_or_else(
                    || format!("{qualified}.{method_name}"),
                    |role| format!("{qualified}.{method_name}.{role}"),
                );
                let (line, col) = lines.loc0(start);
                let loc = count_loc(&text_orig);
                methods.push(ModuleDef {
                    kind: "methods".to_owned(),
                    name,
                    file: file.to_owned(),
                    line,
                    col,
                    text,
                    text_orig,
                    loc,
                    // Includes the receiver: it is removed from `text` only.
                    args: count_args(&func.parameters),
                    lang: Language::Python,
                });
            }
            _ => {}
        }
    }
    methods
}
/// Parses `source` as a Python module and returns its reportable definitions.
///
/// Each top-level class contributes its methods first, followed by the entry
/// for the class itself. Other top-level statements go through [`classify`].
/// Source that fails to parse yields an empty vector.
fn module_defs_from(source: &str, file: &str) -> Vec<ModuleDef> {
    let module = match parse_module(source) {
        Ok(parsed) => parsed.into_syntax(),
        Err(_) => return Vec::new(),
    };
    let lines = LineMap::new(source);
    let mut defs = Vec::new();
    for stmt in &module.body {
        // `class_method_defs` returns nothing for non-class statements.
        defs.extend(class_method_defs(source, stmt, &lines, file, ""));
        if let Some((kind, name, start, end, args)) = classify(source, stmt) {
            let (line, col) = lines.loc0(start);
            let text = source[start..end].to_owned();
            defs.push(ModuleDef {
                kind: kind.to_owned(),
                name,
                file: file.to_owned(),
                line,
                col,
                loc: count_loc(&text),
                text_orig: text.clone(),
                text,
                args,
                lang: Language::Python,
            });
        }
    }
    defs
}
/// Reads the Python file at `file` and returns its definitions.
///
/// A file that cannot be read as UTF-8 text (missing, unreadable, invalid
/// encoding) yields an empty vector instead of an error.
fn module_defs_in(file: &str) -> Vec<ModuleDef> {
    fs::read_to_string(file)
        .map(|source| module_defs_from(&source, file))
        .unwrap_or_default()
}
#[must_use]
pub fn find_module_defs(files: &[String]) -> Vec<ModuleDef> {
let per_file: Vec<Vec<ModuleDef>> = files.par_iter().map(|f| module_defs_in(f)).collect();
per_file.into_iter().flatten().collect()
}
#[cfg(test)]
mod tests {
use super::module_defs_from;
fn triples(src: &str) -> Vec<(String, String, String)> {
module_defs_from(src, "<test>").into_iter().map(|d| (d.kind, d.name, d.text)).collect()
}
#[test]
fn finds_top_level_kinds_and_class_methods() {
let src = "MAX = 5\nlower = 1\n\ntype Ids = list[int]\n\n\ndef top():\n def nested():\n pass\n return 1\n\n\nclass C:\n def method(self):\n return self.x + 1\n";
let got = triples(src);
let kinds: Vec<&str> = got.iter().map(|(k, _, _)| k.as_str()).collect();
let names: Vec<&str> = got.iter().map(|(_, n, _)| n.as_str()).collect();
assert!(names.contains(&"MAX") && names.contains(&"Ids"));
assert!(names.contains(&"top") && names.contains(&"C"));
assert!(names.contains(&"C.method"));
assert!(!names.contains(&"lower") && !names.contains(&"nested") && !names.contains(&"method"));
assert_eq!(kinds.iter().filter(|k| **k == "functions").count(), 1);
assert_eq!(kinds.iter().filter(|k| **k == "classes").count(), 1);
assert_eq!(kinds.iter().filter(|k| **k == "methods").count(), 1);
}
#[test]
fn class_methods_emitted_with_qualified_names_and_methods_kind() {
let src = "class Foo:\n def __init__(self, x):\n self.x = x\n\n async def fetch(self):\n return self.x\n";
let got = triples(src);
let methods: Vec<&str> =
got.iter().filter(|(k, _, _)| k == "methods").map(|(_, n, _)| n.as_str()).collect();
assert!(methods.contains(&"Foo.__init__"), "got methods: {methods:?}");
assert!(methods.contains(&"Foo.fetch"), "got methods: {methods:?}");
let init = got.iter().find(|(_, n, _)| n == "Foo.__init__").expect("init method");
assert!(init.2.starts_with("def "), "method text should start at def, got: {:?}", init.2);
}
#[test]
fn nested_class_methods_use_chained_parent_names() {
let src = "class Outer:\n def outer_m(self):\n return self.x + 1\n\n class Inner:\n def inner_m(self):\n return self.x + 2\n\n class Deep:\n def deep_m(self):\n return self.x + 3\n";
let got = triples(src);
let methods: Vec<&str> =
got.iter().filter(|(k, _, _)| k == "methods").map(|(_, n, _)| n.as_str()).collect();
assert!(methods.contains(&"Outer.outer_m"), "got methods: {methods:?}");
assert!(methods.contains(&"Outer.Inner.inner_m"), "got methods: {methods:?}");
assert!(methods.contains(&"Outer.Inner.Deep.deep_m"), "got methods: {methods:?}");
}
#[test]
fn single_token_return_dispatch_overrides_are_skipped() {
let src = concat!(
"class A:\n",
" def is_x(self) -> bool:\n",
" return False\n",
" def default(self):\n",
" return None\n",
" def name(self) -> str:\n",
" return \"a\"\n",
" def num(self) -> int:\n",
" return 0\n",
" def empty(self):\n",
" return\n",
" def self_(self):\n",
" return self\n",
" def get_x(self):\n",
" return self._x\n",
" def sources(self):\n",
" return [self._x]\n",
" def call(self):\n",
" return self.parent.fn()\n",
);
let got = triples(src);
let methods: Vec<&str> =
got.iter().filter(|(k, _, _)| k == "methods").map(|(_, n, _)| n.as_str()).collect();
for skipped in ["A.is_x", "A.default", "A.name", "A.num", "A.empty", "A.self_"] {
assert!(!methods.contains(&skipped), "{skipped} should be skipped, got: {methods:?}");
}
for kept in ["A.get_x", "A.sources", "A.call"] {
assert!(methods.contains(&kept), "{kept} should be kept, got: {methods:?}");
}
}
#[test]
fn raise_not_implemented_stubs_are_skipped() {
let src = concat!(
"class IFoo:\n",
" def do(self, x: int) -> int:\n",
" raise NotImplementedError\n",
"\n",
" def go(self, x: int) -> int:\n",
" raise NotImplementedError('subclass me')\n",
"\n",
" def real(self, x: int) -> int:\n",
" return x + 1\n",
);
let got = triples(src);
let methods: Vec<&str> =
got.iter().filter(|(k, _, _)| k == "methods").map(|(_, n, _)| n.as_str()).collect();
assert_eq!(methods, vec!["IFoo.real"], "got: {methods:?}");
}
#[test]
fn overload_and_abstract_stubs_are_skipped_real_impl_kept() {
let src = concat!(
"from typing import overload\n",
"from abc import abstractmethod\n",
"\n",
"class C:\n",
" @overload\n",
" def foo(self, x: int) -> int: ...\n",
" @overload\n",
" def foo(self, x: str) -> str: ...\n",
" def foo(self, x):\n",
" return x + 1\n",
"\n",
" @abstractmethod\n",
" def bar(self):\n",
" \"\"\"abstract.\"\"\"\n",
"\n",
" @abstractmethod\n",
" def baz(self):\n",
" pass\n",
"\n",
" def qux(self):\n",
" ...\n",
);
let got = triples(src);
let methods: Vec<&str> =
got.iter().filter(|(k, _, _)| k == "methods").map(|(_, n, _)| n.as_str()).collect();
assert_eq!(methods, vec!["C.foo"], "expected only the real impl, got: {methods:?}");
let foo = got.iter().find(|(_, n, _)| n == "C.foo").expect("real foo");
assert!(foo.2.contains("return x + 1"), "expected the real impl, got: {:?}", foo.2);
}
#[test]
fn loc_and_args_are_populated_from_original_source() {
let src = concat!(
"def free(a, b, *, c=3):\n",
" x = a + b\n",
" y = x * c\n",
" return y\n",
"\n",
"class C:\n",
" def method(self, x, y):\n",
" if x > y:\n",
" return x\n",
" return y\n",
);
let defs = super::module_defs_from(src, "<test>");
let free = defs.iter().find(|d| d.name == "free").expect("free fn");
assert_eq!(free.loc, 4, "free loc: {}", free.loc);
assert_eq!(free.args, 3, "free args: {}", free.args);
let method = defs.iter().find(|d| d.name == "C.method").expect("method");
assert_eq!(method.loc, 4, "method loc: {}", method.loc);
assert_eq!(method.args, 3, "method args (incl self): {}", method.args);
}
#[test]
fn method_receiver_is_stripped_from_text() {
let src = concat!(
"class C:\n",
" def one(self):\n",
" return self.x + 1\n",
"\n",
" def two(self, x):\n",
" return x + 1\n",
"\n",
" @classmethod\n",
" def three(cls, x):\n",
" return x * 2\n",
"\n",
" @staticmethod\n",
" def four(x):\n",
" return x * 3\n",
);
let got = triples(src);
let body_of = |name: &str| {
got.iter().find(|(_, n, _)| n == name).map(|(_, _, t)| t.clone()).expect("method")
};
assert!(body_of("C.one").starts_with("def one():"), "got: {:?}", body_of("C.one"));
assert!(body_of("C.two").starts_with("def two(x):"), "got: {:?}", body_of("C.two"));
assert!(body_of("C.three").starts_with("def three(x):"), "got: {:?}", body_of("C.three"));
assert!(body_of("C.four").starts_with("def four(x):"), "got: {:?}", body_of("C.four"));
}
#[test]
fn property_setter_and_deleter_get_role_suffix() {
let src = concat!(
"class C:\n",
" @property\n",
" def value(self):\n",
" return self._x\n",
"\n",
" @value.setter\n",
" def value(self, v):\n",
" self._x = v\n",
"\n",
" @value.deleter\n",
" def value(self):\n",
" del self._x\n",
);
let got = triples(src);
let methods: Vec<&str> =
got.iter().filter(|(k, _, _)| k == "methods").map(|(_, n, _)| n.as_str()).collect();
assert!(methods.contains(&"C.value"), "getter: {methods:?}");
assert!(methods.contains(&"C.value.setter"), "setter: {methods:?}");
assert!(methods.contains(&"C.value.deleter"), "deleter: {methods:?}");
}
#[test]
fn property_with_real_logic_is_kept() {
let src = concat!(
"class C:\n",
" @property\n",
" def value(self):\n",
" if self._cached is None:\n",
" self._cached = self._compute()\n",
" return self._cached\n",
"\n",
" @value.setter\n",
" def value(self, v):\n",
" self._cached = v\n",
" self._dirty = True\n",
);
let got = triples(src);
let names: Vec<&str> =
got.iter().filter(|(k, _, _)| k == "methods").map(|(_, n, _)| n.as_str()).collect();
assert!(names.contains(&"C.value"), "getter: {names:?}");
assert!(names.contains(&"C.value.setter"), "setter: {names:?}");
}
#[test]
fn class_hidden_inside_function_does_not_surface_methods() {
let src = "def factory():\n class Hidden:\n def helper(self):\n return 1\n return Hidden\n";
let got = triples(src);
let methods: Vec<&str> =
got.iter().filter(|(k, _, _)| k == "methods").map(|(_, n, _)| n.as_str()).collect();
assert!(methods.is_empty(), "no methods expected, got: {methods:?}");
}
#[test]
fn decorated_method_text_excludes_decorators() {
let src = "class C:\n @staticmethod\n def helper(x):\n return x + 1\n";
let got = triples(src);
let helper = got.iter().find(|(_, n, _)| n == "C.helper").expect("helper method");
assert!(helper.2.starts_with("def "), "decorated method text should start at def, got: {:?}", helper.2);
}
#[test]
fn function_text_excludes_decorators() {
let got = triples("import functools\n\n\n@functools.cache\ndef memo(x):\n return x + 1\n");
let func = got.iter().find(|(k, _, _)| k == "functions").expect("a function");
assert!(func.2.starts_with("def "), "text should start at def, got: {:?}", func.2);
}
#[test]
fn pep695_and_modern_syntax_file_is_scanned() {
let src = "type Alias = list[int]\n\n\ndef worker[T](x: T) -> T:\n msg = f\"got {x['k']}\"\n return x\n\n\nclass Repo[T]:\n pass\n";
let names: Vec<String> = module_defs_from(src, "<test>").into_iter().map(|d| d.name).collect();
assert!(names.contains(&"Alias".to_owned()), "type alias missing: {names:?}");
assert!(names.contains(&"worker".to_owned()), "generic fn missing: {names:?}");
assert!(names.contains(&"Repo".to_owned()), "generic class missing: {names:?}");
}
}