use std::path::Path;
use tree_sitter::{Node, Tree};
use crate::ast::parser::ParserPool;
use crate::types::{InheritanceNode, Language};
use crate::TldrResult;
/// Parses `source` as Python and collects an [`InheritanceNode`] for each
/// class definition that sits directly under the module root.
///
/// # Arguments
/// * `source` - Python source text to analyse.
/// * `file_path` - Path stored on every returned node.
/// * `parser_pool` - Pool used to obtain the parse tree.
///
/// # Errors
/// Propagates any error returned by `parser_pool.parse`.
pub fn extract_classes(
    source: &str,
    file_path: &Path,
    parser_pool: &ParserPool,
) -> TldrResult<Vec<InheritanceNode>> {
    let parsed = parser_pool.parse(source, Language::Python)?;

    let mut found: Vec<InheritanceNode> = Vec::new();
    extract_classes_from_tree(&parsed, source, file_path, &mut found);

    Ok(found)
}
/// Appends to `classes` every class found directly under the root of `tree`.
///
/// Two shapes are recognised among the root's immediate children:
/// a bare `class_definition`, and a `decorated_definition` that wraps a
/// `class_definition`. Nothing nested deeper than that is visited.
fn extract_classes_from_tree(
    tree: &Tree,
    source: &str,
    file_path: &Path,
    classes: &mut Vec<InheritanceNode>,
) {
    let module = tree.root_node();
    let mut module_cursor = module.walk();

    for top_level in module.children(&mut module_cursor) {
        match top_level.kind() {
            "class_definition" => {
                // `Option` is iterable, so a `None` result adds nothing.
                classes.extend(extract_class_def(&top_level, source, file_path));
            }
            "decorated_definition" => {
                // The wrapped class is a child of the decorated node,
                // alongside the decorator nodes themselves.
                let mut inner_cursor = top_level.walk();
                let wrapped = top_level
                    .children(&mut inner_cursor)
                    .filter(|inner| inner.kind() == "class_definition")
                    .filter_map(|inner| extract_class_def(&inner, source, file_path));
                classes.extend(wrapped);
            }
            _ => {}
        }
    }
}
/// Builds an [`InheritanceNode`] from a Python `class_definition` node.
///
/// Returns `None` when the node has no `name` field or the name's text cannot
/// be read as UTF-8.
///
/// Each child of the optional `superclasses` field is handled by kind:
/// * `identifier` - its text is recorded as a base.
/// * `attribute` - the name produced by `extract_attribute_name` is recorded.
/// * `subscript` - the name produced by `extract_subscript_name` is recorded.
/// * `keyword_argument` - a `metaclass=...` pair sets the metaclass; other
///   keywords are ignored.
/// * anything else (punctuation, calls, splats, ...) is skipped.
///
/// `is_abstract` becomes `Some(true)` when a base is named `ABC` or the class
/// body has an `abstractmethod`-decorated member. `protocol` becomes
/// `Some(true)` when a base is `Protocol` or ends with `.Protocol`. Both
/// flags are otherwise left as `InheritanceNode::new` initialised them.
fn extract_class_def(node: &Node, source: &str, file_path: &Path) -> Option<InheritanceNode> {
    let src = source.as_bytes();

    let name = node
        .child_by_field_name("name")?
        .utf8_text(src)
        .ok()?
        .to_string();
    // Parser rows are zero-based; the stored line number is one-based.
    let line = node.start_position().row as u32 + 1;

    // `name` is not needed afterwards, so it is moved rather than cloned.
    let mut class_node =
        InheritanceNode::new(name, file_path.to_path_buf(), line, Language::Python);

    if let Some(args) = node.child_by_field_name("superclasses") {
        let mut bases = Vec::new();
        let mut metaclass = None;

        // Walk the argument list with a cursor instead of repeated indexed
        // `child(i)` lookups.
        let mut cursor = args.walk();
        for arg in args.children(&mut cursor) {
            match arg.kind() {
                "identifier" => {
                    if let Ok(base_name) = arg.utf8_text(src) {
                        bases.push(base_name.to_string());
                    }
                }
                "attribute" => bases.extend(extract_attribute_name(&arg, source)),
                "subscript" => bases.extend(extract_subscript_name(&arg, source)),
                "keyword_argument" => {
                    if let Some((key, value)) = extract_keyword_arg(&arg, source) {
                        if key == "metaclass" {
                            metaclass = Some(value);
                        }
                    }
                }
                _ => {}
            }
        }

        class_node.bases = bases;
        class_node.metaclass = metaclass;
    }

    // Compare by reference rather than allocating a `String` for the lookup.
    // The cheap base-name check runs first; both checks are side-effect free.
    let inherits_abc = class_node.bases.iter().any(|b| b == "ABC");
    if inherits_abc || has_abstractmethod_decorator(node, source) {
        class_node.is_abstract = Some(true);
    }

    let is_protocol = class_node
        .bases
        .iter()
        .any(|b| b == "Protocol" || b.ends_with(".Protocol"));
    if is_protocol {
        class_node.protocol = Some(true);
    }

    Some(class_node)
}
/// Returns the text of the child stored in `node`'s `attribute` field.
///
/// Only that field is read; whatever precedes it in the expression is not
/// included in the result (presumably `pkg.Base` yields `Base` — confirm
/// against the grammar). Yields `None` if the field is absent or its text
/// cannot be read as UTF-8.
fn extract_attribute_name(node: &Node, source: &str) -> Option<String> {
    let field = node.child_by_field_name("attribute")?;
    match field.utf8_text(source.as_bytes()) {
        Ok(text) => Some(text.to_owned()),
        Err(_) => None,
    }
}
/// Resolves the name of the expression being subscripted by `node`.
///
/// Looks at the node's `value` field: an `identifier` contributes its own
/// text, an `attribute` is delegated to `extract_attribute_name`, and any
/// other kind (or a missing field / unreadable text) gives `None`.
fn extract_subscript_name(node: &Node, source: &str) -> Option<String> {
    let subscripted = node.child_by_field_name("value")?;
    let kind = subscripted.kind();

    if kind == "identifier" {
        return subscripted
            .utf8_text(source.as_bytes())
            .ok()
            .map(str::to_owned);
    }
    if kind == "attribute" {
        return extract_attribute_name(&subscripted, source);
    }
    None
}
/// Splits a `keyword_argument` node into a `(key, value)` pair of strings.
///
/// The key is the text of the `name` field. The value is taken from the
/// `value` field and is only supported when that child is an `identifier`
/// (its text) or an `attribute` (via `extract_attribute_name`).
///
/// Returns `None` if either field is missing, any text cannot be read as
/// UTF-8, or the value is some other kind of expression.
fn extract_keyword_arg(node: &Node, source: &str) -> Option<(String, String)> {
    let key_node = node.child_by_field_name("name")?;
    let value_node = node.child_by_field_name("value")?;
    let bytes = source.as_bytes();

    let key = key_node.utf8_text(bytes).ok()?;

    let rendered = if value_node.kind() == "identifier" {
        value_node.utf8_text(bytes).ok()?.to_owned()
    } else if value_node.kind() == "attribute" {
        extract_attribute_name(&value_node, source)?
    } else {
        return None;
    };

    Some((key.to_owned(), rendered))
}
/// Reports whether any decorated member directly inside the class body
/// carries a decorator whose source text contains `abstractmethod`.
///
/// Only immediate children of the `body` field are inspected, and the check
/// is a plain substring match on the decorator's text. Returns `false` when
/// the class has no `body` field.
fn has_abstractmethod_decorator(class_node: &Node, source: &str) -> bool {
    let body = match class_node.child_by_field_name("body") {
        Some(body) => body,
        None => return false,
    };
    let bytes = source.as_bytes();

    let mut body_cursor = body.walk();
    for member in body.children(&mut body_cursor) {
        if member.kind() != "decorated_definition" {
            continue;
        }

        let mut member_cursor = member.walk();
        let marked = member.children(&mut member_cursor).any(|part| {
            part.kind() == "decorator"
                && part
                    .utf8_text(bytes)
                    .map_or(false, |text| text.contains("abstractmethod"))
        });
        if marked {
            return true;
        }
    }

    false
}
// Unit tests for the Python class extractor. Each test feeds a small Python
// snippet through `extract_classes` and checks the resulting nodes.
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
// Helper: parses `source` with a fresh `ParserPool` under the placeholder
// path `test.py` and unwraps the result, so a parse error fails the test.
fn parse_and_extract(source: &str) -> Vec<InheritanceNode> {
let pool = ParserPool::new();
extract_classes(source, &PathBuf::from("test.py"), &pool).unwrap()
}
// A class with no superclass list yields one node with empty `bases`.
#[test]
fn test_simple_class() {
let source = r#"
class Animal:
pass
"#;
let classes = parse_and_extract(source);
assert_eq!(classes.len(), 1);
assert_eq!(classes[0].name, "Animal");
assert!(classes[0].bases.is_empty());
}
// A single identifier base is recorded on the derived class.
#[test]
fn test_single_inheritance() {
let source = r#"
class Animal:
pass
class Dog(Animal):
pass
"#;
let classes = parse_and_extract(source);
assert_eq!(classes.len(), 2);
let dog = classes.iter().find(|c| c.name == "Dog").unwrap();
assert_eq!(dog.bases, vec!["Animal"]);
}
// Multiple bases are recorded in declaration order.
#[test]
fn test_multiple_inheritance() {
let source = r#"
class User(Base, TimestampMixin, AuditMixin):
pass
"#;
let classes = parse_and_extract(source);
assert_eq!(classes.len(), 1);
assert_eq!(
classes[0].bases,
vec!["Base", "TimestampMixin", "AuditMixin"]
);
}
// A class deriving from `ABC` is flagged abstract and keeps `ABC` as a base.
#[test]
fn test_abc_detection() {
let source = r#"
from abc import ABC, abstractmethod
class Animal(ABC):
@abstractmethod
def speak(self):
pass
"#;
let classes = parse_and_extract(source);
assert_eq!(classes.len(), 1);
assert_eq!(classes[0].is_abstract, Some(true));
assert!(classes[0].bases.contains(&"ABC".to_string()));
}
// A class deriving from `Protocol` has its `protocol` flag set.
#[test]
fn test_protocol_detection() {
let source = r#"
from typing import Protocol
class Serializable(Protocol):
def serialize(self) -> dict:
...
"#;
let classes = parse_and_extract(source);
assert_eq!(classes.len(), 1);
assert_eq!(classes[0].protocol, Some(true));
}
// A `metaclass=` keyword argument is captured in `metaclass`.
#[test]
fn test_metaclass_extraction() {
let source = r#"
class Singleton(metaclass=SingletonMeta):
pass
"#;
let classes = parse_and_extract(source);
assert_eq!(classes.len(), 1);
assert_eq!(classes[0].metaclass, Some("SingletonMeta".to_string()));
}
// A subscripted base such as `Generic[T]` is recorded by the name being
// subscripted (`Generic`).
#[test]
fn test_generic_base() {
let source = r#"
from typing import Generic, TypeVar
T = TypeVar('T')
class Container(Generic[T]):
pass
"#;
let classes = parse_and_extract(source);
let container = classes.iter().find(|c| c.name == "Container").unwrap();
assert!(container.bases.contains(&"Generic".to_string()));
}
}